Unverified commit b0edda4d authored by Double_V, committed by GitHub

Kunlun: add XPU ops (#27890)

* add stack, pool2d, roi_align XPU ops, test=kunlun

* optimize error messages, test=kunlun

* add XPU unittests, test=kunlun

* skip check_grad, test=kunlun

* fix BOOST_GET usage, test=kunlun
Parent 82eb486e
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/pool_op.h"
#include <unordered_map>
#ifdef PADDLE_WITH_XPU
namespace paddle {
namespace operators {
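// Map the Paddle pooling attributes onto the XPU runtime's Pooling_t enum.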
xpu::Pooling_t XPUPoolingType(const std::string& pooltype, bool exclusive,
bool is_test) {
if (pooltype == "max") {
return xpu::Pooling_t::MAX_WITHOUT_INDEX;
} else if (pooltype == "avg") {
if (exclusive) {
return xpu::Pooling_t::AVG_WITHOUT_PAD;
} else {
return xpu::Pooling_t::AVG_WITH_PAD;
}
} else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "Unsupported pooling type: %s. The XPU pool2d kernel only supports "
        "\"max\" and \"avg\".",
        pooltype));
}
}
template <typename DeviceContext, typename T>
class PoolXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* in_x = context.Input<Tensor>("X");
Tensor* out = context.Output<Tensor>("Out");
std::string pooling_type = context.Attr<std::string>("pooling_type");
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
bool exclusive = context.Attr<bool>("exclusive");
bool is_test = context.Attr<bool>("is_test");
bool adaptive = context.Attr<bool>("adaptive");
PADDLE_ENFORCE_EQ(!adaptive, true,
platform::errors::InvalidArgument(
"XPU does not support adaptive == true!"));
    PADDLE_ENFORCE_EQ(ksize.size(), 2,
                      platform::errors::InvalidArgument(
                          "The XPU pool2d kernel only supports 2-D pooling!"));
int* index_data = nullptr;
if (context.Attr<bool>("global_pooling")) {
for (size_t i = 0; i < ksize.size(); ++i) {
paddings[i] = 0;
ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
}
}
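    // The XPU pooling primitive treats the input as (N * C, H, W), so fold
    // the batch and channel dimensions into a single count.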
const int c = in_x->dims()[0] * in_x->dims()[1];
const int in_h = in_x->dims()[2];
const int in_w = in_x->dims()[3];
const int out_h = out->dims()[2];
const int out_w = out->dims()[3];
const int win_h = ksize[0];
const int win_w = ksize[1];
const int stride_h = strides[0];
const int stride_w = strides[1];
const int pad_up = paddings[0];
const int pad_down = paddings[0];
const int pad_left = paddings[1];
const int pad_right = paddings[1];
const float* input = in_x->data<float>();
out->mutable_data<T>(context.GetPlace());
float* output = out->data<float>();
xpu::Pooling_t pool_type = XPUPoolingType(pooling_type, exclusive, is_test);
auto& dev_ctx = context.template device_context<DeviceContext>();
int r = xpu::pooling_forward<float, float>(
dev_ctx.x_context(), input, output, index_data, pool_type, c, in_h,
in_w, pad_left, pad_right, pad_up, pad_down, win_h, win_w, stride_h,
stride_w, out_h, out_w);
    PADDLE_ENFORCE_EQ(
        r, xpu::Error_t::SUCCESS,
        platform::errors::External(
            "The pool2d XPU kernel failed, error code: %d.", r));
}
};
template <typename DeviceContext, typename T>
class PoolGradXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* in_x = context.Input<Tensor>("X");
const Tensor* out = context.Input<Tensor>("Out");
const Tensor* out_grad =
context.Input<Tensor>(framework::GradVarName("Out"));
Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X"));
std::string pooling_type = context.Attr<std::string>("pooling_type");
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
bool exclusive = context.Attr<bool>("exclusive");
bool adaptive = context.Attr<bool>("adaptive");
const int* index_data = nullptr;
PADDLE_ENFORCE_EQ(!adaptive, true,
platform::errors::InvalidArgument(
"XPU does not support adaptive == true!"));
    PADDLE_ENFORCE_EQ(ksize.size(), 2,
                      platform::errors::InvalidArgument(
                          "The XPU pool2d grad kernel only supports 2-D "
                          "pooling!"));
if (context.Attr<bool>("global_pooling")) {
for (size_t i = 0; i < ksize.size(); ++i) {
paddings[i] = 0;
ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
}
}
if (!in_x_grad) {
return;
}
const int c = in_x->dims()[0] * in_x->dims()[1];
const int in_h = in_x->dims()[2];
const int in_w = in_x->dims()[3];
const int out_h = out->dims()[2];
const int out_w = out->dims()[3];
const int win_h = ksize[0];
const int win_w = ksize[1];
const int stride_h = strides[0];
const int stride_w = strides[1];
const int pad_up = paddings[0];
const int pad_down = paddings[0];
const int pad_left = paddings[1];
const int pad_right = paddings[1];
const float* input = in_x->data<float>();
const float* output = out->data<float>();
const float* output_grad = out_grad->data<float>();
in_x_grad->mutable_data<T>(context.GetPlace());
float* input_grad = in_x_grad->data<float>();
xpu::Pooling_t pool_type = XPUPoolingType(pooling_type, exclusive, false);
auto& dev_ctx = context.template device_context<DeviceContext>();
// Need to init memory in the first place
const int zero = 0;
int r =
xpu::memset(dev_ctx.x_context(), reinterpret_cast<void**>(input_grad),
zero, in_x_grad->numel() * sizeof(float));
    PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
                      platform::errors::External(
                          "XPU memset failed in the pool2d grad kernel, "
                          "error code: %d.",
                          r));
r = xpu::pooling_backward(dev_ctx.x_context(), input, output, index_data,
output_grad, input_grad, pool_type, c, in_h, in_w,
pad_left, pad_right, pad_up, pad_down, win_h,
win_w, stride_h, stride_w, out_h, out_w);
    PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
                      platform::errors::External(
                          "The pool2d grad XPU kernel failed, error code: %d.",
                          r));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(
pool2d, ops::PoolXPUKernel<paddle::platform::XPUDeviceContext, float>);
REGISTER_OP_XPU_KERNEL(
pool2d_grad,
ops::PoolGradXPUKernel<paddle::platform::XPUDeviceContext, float>);
#endif
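For reference, a minimal sketch of exercising the pool2d XPU kernel from Python. It assumes a Paddle build with XPU support and a device at index 0; the shapes and attributes are illustrative and mirror the unit test further down:

import numpy as np
import paddle
import paddle.fluid as fluid

paddle.enable_static()
place = paddle.XPUPlace(0)  # assumes paddle.is_compiled_with_xpu()
main_prog, startup_prog = fluid.Program(), fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    x = fluid.data(name='x', shape=[2, 3, 5, 5], dtype='float32')
    # avg pooling, 3x3 window, stride 1, explicit zero padding
    y = fluid.layers.pool2d(x, pool_size=3, pool_type='avg', pool_stride=1,
                            pool_padding=0)
exe = fluid.Executor(place)
exe.run(startup_prog)
out, = exe.run(main_prog,
               feed={'x': np.random.rand(2, 3, 5, 5).astype('float32')},
               fetch_list=[y])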
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/roi_align_op.h"
#include <memory>
#include <string>
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class XPUROIAlignOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::Tensor>("X");
auto* rois = ctx.Input<framework::LoDTensor>("ROIs");
auto* out = ctx.Output<framework::Tensor>("Out");
auto pooled_height = ctx.Attr<int>("pooled_height");
auto pooled_width = ctx.Attr<int>("pooled_width");
auto spatial_scale = ctx.Attr<float>("spatial_scale");
auto sampling_ratio = ctx.Attr<int>("sampling_ratio");
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto in_dims = in->dims();
int batch_size = in_dims[0];
int channels = in_dims[1];
int height = in_dims[2];
int width = in_dims[3];
int rois_num = rois->dims()[0];
const T* input_data = in->data<T>();
auto rois_lod = rois->lod().back();
int rois_batch_size = rois_lod.size() - 1;
    PADDLE_ENFORCE_EQ(
        rois_batch_size, batch_size,
        platform::errors::InvalidArgument(
            "The ROIs' batch size must equal the input images' batch size."));
int rois_num_with_lod = rois_lod[rois_batch_size];
    PADDLE_ENFORCE_EQ(rois_num, rois_num_with_lod,
                      platform::errors::InvalidArgument(
                          "The number of ROIs from the input must equal the "
                          "number implied by the LoD."));
T* output_data = out->mutable_data<T>(ctx.GetPlace());
const T* rois_data = rois->data<T>();
for (int n = 0; n < rois_batch_size; n++) {
int cur_batch_rois_num = rois_lod[n + 1] - rois_lod[n];
if (cur_batch_rois_num != 0) {
int r = xpu::roi_align(
dev_ctx.x_context(), input_data + n * channels * height * width,
rois_data + rois_lod[n] * 4, cur_batch_rois_num, channels, height,
width, pooled_height, pooled_width, sampling_ratio, spatial_scale,
output_data +
rois_lod[n] * channels * pooled_height * pooled_width);
        PADDLE_ENFORCE_EQ(
            r, xpu::Error_t::SUCCESS,
            platform::errors::External(
                "The roi_align XPU kernel failed, error code: %d.", r));
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(
roi_align,
ops::XPUROIAlignOpKernel<paddle::platform::XPUDeviceContext, float>);
#endif
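Likewise, a minimal sketch of invoking the roi_align XPU kernel through fluid.layers.roi_align, assuming an XPU build. The LoD distributes the ROIs over the batch images, mirroring the per-image loop in the kernel above:

import numpy as np
import paddle
import paddle.fluid as fluid

paddle.enable_static()
place = paddle.XPUPlace(0)
main_prog, startup_prog = fluid.Program(), fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    x = fluid.data(name='x', shape=[2, 3, 8, 6], dtype='float32')
    rois = fluid.data(name='rois', shape=[None, 4], dtype='float32',
                      lod_level=1)
    out = fluid.layers.roi_align(x, rois, pooled_height=2, pooled_width=2,
                                 spatial_scale=0.5, sampling_ratio=-1)
exe = fluid.Executor(place)
exe.run(startup_prog)
rois_np = np.array([[0, 0, 6, 6], [2, 2, 10, 10], [0, 0, 4, 4]], 'float32')
# 1 ROI for image 0, 2 ROIs for image 1: rois_batch_size must match batch_size
rois_t = fluid.create_lod_tensor(rois_np, [[1, 2]], place)
res, = exe.run(main_prog,
               feed={'x': np.random.rand(2, 3, 8, 6).astype('float32'),
                     'rois': rois_t},
               fetch_list=[out])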
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/stack_op.h"
#ifdef PADDLE_WITH_XPU
namespace paddle {
namespace operators {
using framework::Tensor;
template <typename DeviceContext, typename T>
class StackXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto x = ctx.MultiInput<Tensor>("X");
auto* y = ctx.Output<Tensor>("Y");
int axis = ctx.Attr<int>("axis");
if (axis < 0) {
axis += (x[0]->dims().size() + 1);
}
int n = static_cast<int>(x.size());
    PADDLE_ENFORCE_LE(n, 24,
                      platform::errors::InvalidArgument(
                          "XPU only supports stacking at most 24 tensors "
                          "for now."));
auto* y_data = y->mutable_data<T>(ctx.GetPlace());
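    // The XPU stack primitive views each input as a (pre x post) matrix:
    // pre is the product of the dimensions before the stack axis, post the
    // product of the remaining dimensions.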
int pre = 1, post = 1;
auto& dim = x[0]->dims();
for (auto i = 0; i < axis; ++i) {
pre *= dim[i];
}
for (auto i = axis; i < dim.size(); ++i) {
post *= dim[i];
}
auto& dev_ctx = ctx.template device_context<DeviceContext>();
void* x_datas_host = std::malloc(n * sizeof(void*));
void* x_datas_device = nullptr;
    PADDLE_ENFORCE_EQ(
        xpu_malloc(reinterpret_cast<void**>(&x_datas_device),
                   n * sizeof(void*)),
        XPU_SUCCESS,
        platform::errors::ResourceExhausted(
            "Failed to allocate XPU memory for the input pointer array."));
for (auto i = 0; i < n; ++i) {
      reinterpret_cast<const void**>(x_datas_host)[i] = x[i]->data<T>();
}
memory::Copy(BOOST_GET_CONST(platform::XPUPlace, ctx.GetPlace()),
x_datas_device, platform::CPUPlace(), x_datas_host,
n * sizeof(void*));
int r = xpu::stack_forward<float>(dev_ctx.x_context(), pre, post, n,
x_datas_device, y_data);
    PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
                      platform::errors::External(
                          "The stack XPU kernel failed, error code: %d.", r));
dev_ctx.Wait();
std::free(x_datas_host);
xpu_free(x_datas_device);
}
};
} // namespace operators
} // namespace paddle
namespace plat = paddle::platform;
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(stack,
ops::StackXPUKernel<plat::XPUDeviceContext, float>);
#endif
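Finally, a minimal sketch of driving the stack XPU kernel from Python, assuming an XPU build; the four inputs stay well under the 24-tensor limit enforced above:

import numpy as np
import paddle
import paddle.fluid as fluid

paddle.enable_static()
place = paddle.XPUPlace(0)
main_prog, startup_prog = fluid.Program(), fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    xs = [fluid.data(name='x%d' % i, shape=[5, 6, 7], dtype='float32')
          for i in range(4)]
    y = fluid.layers.stack(xs, axis=0)  # -> shape [4, 5, 6, 7]
exe = fluid.Executor(place)
exe.run(startup_prog)
feed = {'x%d' % i: np.random.rand(5, 6, 7).astype('float32')
        for i in range(4)}
out, = exe.run(main_prog, feed=feed, fetch_list=[y])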
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import sys
sys.path.append("..")
import paddle.fluid.core as core
import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
def max_pool2D_forward_naive(x,
ksize,
strides,
paddings,
global_pool=0,
ceil_mode=False,
exclusive=True,
adaptive=False,
data_type=np.float64):
N, C, H, W = x.shape
if global_pool == 1:
ksize = [H, W]
if adaptive:
H_out, W_out = ksize
else:
H_out = (H - ksize[0] + 2 * paddings[0] + strides[0] - 1
) // strides[0] + 1 if ceil_mode else (
H - ksize[0] + 2 * paddings[0]) // strides[0] + 1
W_out = (W - ksize[1] + 2 * paddings[1] + strides[1] - 1
) // strides[1] + 1 if ceil_mode else (
W - ksize[1] + 2 * paddings[1]) // strides[1] + 1
out = np.zeros((N, C, H_out, W_out))
for i in range(H_out):
for j in range(W_out):
if adaptive:
r_start = adaptive_start_index(i, H, ksize[0])
r_end = adaptive_end_index(i, H, ksize[0])
c_start = adaptive_start_index(j, W, ksize[1])
c_end = adaptive_end_index(j, W, ksize[1])
else:
r_start = np.max((i * strides[0] - paddings[0], 0))
r_end = np.min((i * strides[0] + ksize[0] - paddings[0], H))
c_start = np.max((j * strides[1] - paddings[1], 0))
c_end = np.min((j * strides[1] + ksize[1] - paddings[1], W))
x_masked = x[:, :, r_start:r_end, c_start:c_end]
out[:, :, i, j] = np.max(x_masked, axis=(2, 3))
return out
def avg_pool2D_forward_naive(x,
ksize,
strides,
paddings,
global_pool=0,
ceil_mode=False,
exclusive=True,
adaptive=False,
data_type=np.float64):
N, C, H, W = x.shape
if global_pool == 1:
ksize = [H, W]
if adaptive:
H_out, W_out = ksize
else:
H_out = (H - ksize[0] + 2 * paddings[0] + strides[0] - 1
) // strides[0] + 1 if ceil_mode else (
H - ksize[0] + 2 * paddings[0]) // strides[0] + 1
W_out = (W - ksize[1] + 2 * paddings[1] + strides[1] - 1
) // strides[1] + 1 if ceil_mode else (
W - ksize[1] + 2 * paddings[1]) // strides[1] + 1
out = np.zeros((N, C, H_out, W_out))
for i in range(H_out):
for j in range(W_out):
if adaptive:
r_start = adaptive_start_index(i, H, ksize[0])
r_end = adaptive_end_index(i, H, ksize[0])
c_start = adaptive_start_index(j, W, ksize[1])
c_end = adaptive_end_index(j, W, ksize[1])
else:
r_start = i * strides[0] - paddings[0]
r_end = i * strides[0] + ksize[0] - paddings[0]
c_start = j * strides[1] - paddings[1]
c_end = j * strides[1] + ksize[1] - paddings[1]
field_size = (r_end - r_start) * (c_end - c_start)
r_start = np.max((r_start, 0))
r_end = np.min((r_end, H))
c_start = np.max((c_start, 0))
c_end = np.min((c_end, W))
x_masked = x[:, :, r_start:r_end, c_start:c_end]
if (exclusive or adaptive):
field_size = (r_end - r_start) * (c_end - c_start)
if data_type == np.int8 or data_type == np.uint8:
out[:, :, i, j] = (np.rint(
np.sum(x_masked, axis=(2, 3)) /
field_size)).astype(data_type)
else:
out[:, :, i, j] = (np.sum(x_masked, axis=(2, 3)) /
field_size).astype(data_type)
return out
def pool2D_forward_naive(x,
ksize,
strides,
paddings,
global_pool=0,
ceil_mode=False,
exclusive=True,
adaptive=False,
data_format='NCHW',
pool_type="max",
padding_algorithm="EXPLICIT"):
# update paddings
def _get_padding_with_SAME(input_shape, pool_size, pool_stride):
padding = []
for input_size, filter_size, stride_size in zip(input_shape, pool_size,
pool_stride):
out_size = int((input_size + stride_size - 1) / stride_size)
pad_sum = np.max((
(out_size - 1) * stride_size + filter_size - input_size, 0))
pad_0 = int(pad_sum / 2)
pad_1 = int(pad_sum - pad_0)
padding.append(pad_0)
padding.append(pad_1)
return padding
if isinstance(padding_algorithm, str):
padding_algorithm = padding_algorithm.upper()
if padding_algorithm not in ["SAME", "VALID", "EXPLICIT"]:
            raise ValueError("Unknown Attr(padding_algorithm): '%s'. "
                             "It can only be 'SAME', 'VALID' or 'EXPLICIT'." %
                             str(padding_algorithm))
if padding_algorithm == "VALID":
paddings = [0, 0, 0, 0]
        if ceil_mode:
raise ValueError(
"When Attr(pool_padding) is \"VALID\", Attr(ceil_mode)"
" must be False. "
"Received ceil_mode: True.")
elif padding_algorithm == "SAME":
input_data_shape = []
if data_format == "NCHW":
input_data_shape = x.shape[2:4]
elif data_format == "NHWC":
input_data_shape = x.shape[1:3]
paddings = _get_padding_with_SAME(input_data_shape, ksize, strides)
assert len(paddings) == 2 or len(paddings) == 4
    is_sys = len(paddings) == 2
N = x.shape[0]
C, H, W = [x.shape[1], x.shape[2], x.shape[3]] if data_format == 'NCHW' \
else [x.shape[3], x.shape[1], x.shape[2]]
if global_pool == 1:
ksize = [H, W]
paddings = [0 for _ in range(len(paddings))]
pad_h_up = paddings[0] if is_sys else paddings[0]
pad_h_down = paddings[0] if is_sys else paddings[1]
pad_w_left = paddings[1] if is_sys else paddings[2]
pad_w_right = paddings[1] if is_sys else paddings[3]
if adaptive:
H_out, W_out = ksize
else:
H_out = (H - ksize[0] + pad_h_up + pad_h_down + strides[0] - 1) // strides[0] + 1 \
if ceil_mode else (H - ksize[0] + pad_h_up + pad_h_down) // strides[0] + 1
W_out = (W - ksize[1] + pad_w_left + pad_w_right + strides[1] - 1) // strides[1] + 1 \
if ceil_mode else (W - ksize[1] + pad_w_left + pad_w_right) // strides[1] + 1
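        # e.g. H = 5, ksize[0] = 3, pad_h_up = pad_h_down = 0, stride = 1,
        # ceil_mode False: H_out = (5 - 3 + 0 + 0) // 1 + 1 = 3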
out = np.zeros((N, C, H_out, W_out)) if data_format=='NCHW' \
else np.zeros((N, H_out, W_out, C))
for i in range(H_out):
if adaptive:
in_h_start = adaptive_start_index(i, H, ksize[0])
in_h_end = adaptive_end_index(i, H, ksize[0])
else:
in_h_start = np.max((i * strides[0] - pad_h_up, 0))
in_h_end = np.min((i * strides[0] + ksize[0] - pad_h_up, H))
for j in range(W_out):
if adaptive:
in_w_start = adaptive_start_index(j, W, ksize[1])
in_w_end = adaptive_end_index(j, W, ksize[1])
else:
in_h_start = i * strides[0] - pad_h_up
in_w_start = j * strides[1] - pad_w_left
in_h_end = i * strides[0] + ksize[0] - pad_h_up
in_w_end = j * strides[1] + ksize[1] - pad_w_left
field_size = (in_h_end - in_h_start) * (in_w_end - in_w_start)
in_h_start = np.max((in_h_start, 0))
in_w_start = np.max((in_w_start, 0))
in_h_end = np.min((in_h_end, H))
in_w_end = np.min((in_w_end, W))
if data_format == 'NCHW':
x_masked = x[:, :, in_h_start:in_h_end, in_w_start:in_w_end]
if pool_type == 'avg':
if (exclusive or adaptive):
field_size = (in_h_end - in_h_start) * (
in_w_end - in_w_start)
out[:, :, i, j] = np.sum(x_masked, axis=(2, 3)) / field_size
elif pool_type == 'max':
out[:, :, i, j] = np.max(x_masked, axis=(2, 3))
elif data_format == 'NHWC':
x_masked = x[:, in_h_start:in_h_end, in_w_start:in_w_end, :]
if pool_type == 'avg':
if (exclusive or adaptive):
field_size = (in_h_end - in_h_start) * (
in_w_end - in_w_start)
out[:, i, j, :] = np.sum(x_masked, axis=(1, 2)) / field_size
elif pool_type == 'max':
out[:, i, j, :] = np.max(x_masked, axis=(1, 2))
return out
class TestPool2D_Op(OpTest):
def setUp(self):
self.op_type = "pool2d"
self.use_cudnn = False
self.init_kernel_type()
self.use_mkldnn = False
self.init_data_type()
self.init_test_case()
self.padding_algorithm = "EXPLICIT"
self.init_paddings()
self.init_global_pool()
self.init_kernel_type()
self.init_pool_type()
self.init_ceil_mode()
self.init_exclusive()
self.init_adaptive()
self.init_data_format()
self.init_shape()
input = np.random.random(self.shape).astype(self.dtype)
output = pool2D_forward_naive(
input, self.ksize, self.strides, self.paddings, self.global_pool,
self.ceil_mode, self.exclusive, self.adaptive, self.data_format,
self.pool_type, self.padding_algorithm).astype(self.dtype)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(input)}
self.attrs = {
'strides': self.strides,
'paddings': self.paddings,
'ksize': self.ksize,
'pooling_type': self.pool_type,
'global_pooling': self.global_pool,
'use_cudnn': self.use_cudnn,
'use_mkldnn': self.use_mkldnn,
'ceil_mode': self.ceil_mode,
'data_format': self.data_format,
'exclusive': self.exclusive,
'adaptive': self.adaptive,
"padding_algorithm": self.padding_algorithm,
}
self.outputs = {'Out': output}
def test_check_output(self):
if paddle.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_output_with_place(place)
def test_check_grad(self):
if paddle.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_grad_with_place(
place, set(['X']), 'Out', max_relative_error=0.07)
def init_data_format(self):
self.data_format = "NCHW"
def init_shape(self):
self.shape = [2, 3, 5, 5]
def init_test_case(self):
self.ksize = [3, 3]
self.strides = [1, 1]
def init_paddings(self):
self.paddings = [0, 0]
self.padding_algorithm = "EXPLICIT"
def init_kernel_type(self):
self.use_cudnn = False
def init_data_type(self):
self.dtype = np.float64
def init_pool_type(self):
self.pool_type = "avg"
self.pool2D_forward_naive = avg_pool2D_forward_naive
def init_global_pool(self):
self.global_pool = True
def init_ceil_mode(self):
self.ceil_mode = False
def init_exclusive(self):
self.exclusive = True
def init_adaptive(self):
self.adaptive = False
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import sys
sys.path.append("..")
import unittest
import math
import numpy as np
import paddle.fluid.core as core
from op_test import OpTest, skip_check_grad_ci
import paddle
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
@skip_check_grad_ci(reason="There is no grad kernel for roi_align_xpu kernel.")
class TestROIAlignOp(OpTest):
def set_data(self):
self.init_test_case()
self.make_rois()
self.calc_roi_align()
self.inputs = {
'X': self.x,
'ROIs': (self.rois[:, 1:5], self.rois_lod),
}
self.attrs = {
'spatial_scale': self.spatial_scale,
'pooled_height': self.pooled_height,
'pooled_width': self.pooled_width,
'sampling_ratio': self.sampling_ratio
}
self.outputs = {'Out': self.out_data}
def init_test_case(self):
self.batch_size = 3
self.channels = 3
self.height = 8
self.width = 6
# n, c, h, w
self.x_dim = (self.batch_size, self.channels, self.height, self.width)
self.spatial_scale = 1.0 / 2.0
self.pooled_height = 2
self.pooled_width = 2
self.sampling_ratio = -1
self.x = np.random.random(self.x_dim).astype('float64')
def pre_calc(self, x_i, roi_xmin, roi_ymin, roi_bin_grid_h, roi_bin_grid_w,
bin_size_h, bin_size_w):
count = roi_bin_grid_h * roi_bin_grid_w
bilinear_pos = np.zeros(
[self.channels, self.pooled_height, self.pooled_width, count, 4],
np.float64)
bilinear_w = np.zeros(
[self.pooled_height, self.pooled_width, count, 4], np.float64)
        for ph in range(self.pooled_height):
            for pw in range(self.pooled_width):
c = 0
for iy in range(roi_bin_grid_h):
y = roi_ymin + ph * bin_size_h + (iy + 0.5) * \
bin_size_h / roi_bin_grid_h
for ix in range(roi_bin_grid_w):
x = roi_xmin + pw * bin_size_w + (ix + 0.5) * \
bin_size_w / roi_bin_grid_w
if y < -1.0 or y > self.height or \
x < -1.0 or x > self.width:
continue
if y <= 0:
y = 0
if x <= 0:
x = 0
y_low = int(y)
x_low = int(x)
if y_low >= self.height - 1:
y = y_high = y_low = self.height - 1
else:
y_high = y_low + 1
if x_low >= self.width - 1:
x = x_high = x_low = self.width - 1
else:
x_high = x_low + 1
ly = y - y_low
lx = x - x_low
hy = 1 - ly
hx = 1 - lx
for ch in range(self.channels):
bilinear_pos[ch, ph, pw, c, 0] = x_i[ch, y_low,
x_low]
bilinear_pos[ch, ph, pw, c, 1] = x_i[ch, y_low,
x_high]
bilinear_pos[ch, ph, pw, c, 2] = x_i[ch, y_high,
x_low]
bilinear_pos[ch, ph, pw, c, 3] = x_i[ch, y_high,
x_high]
bilinear_w[ph, pw, c, 0] = hy * hx
bilinear_w[ph, pw, c, 1] = hy * lx
bilinear_w[ph, pw, c, 2] = ly * hx
bilinear_w[ph, pw, c, 3] = ly * lx
c = c + 1
return bilinear_pos, bilinear_w
def calc_roi_align(self):
self.out_data = np.zeros(
(self.rois_num, self.channels, self.pooled_height,
self.pooled_width)).astype('float64')
for i in range(self.rois_num):
roi = self.rois[i]
roi_batch_id = int(roi[0])
x_i = self.x[roi_batch_id]
roi_xmin = roi[1] * self.spatial_scale
roi_ymin = roi[2] * self.spatial_scale
roi_xmax = roi[3] * self.spatial_scale
roi_ymax = roi[4] * self.spatial_scale
roi_width = max(roi_xmax - roi_xmin, 1)
roi_height = max(roi_ymax - roi_ymin, 1)
bin_size_h = float(roi_height) / float(self.pooled_height)
bin_size_w = float(roi_width) / float(self.pooled_width)
roi_bin_grid_h = self.sampling_ratio if self.sampling_ratio > 0 else \
math.ceil(roi_height / self.pooled_height)
roi_bin_grid_w = self.sampling_ratio if self.sampling_ratio > 0 else \
math.ceil(roi_width / self.pooled_width)
count = int(roi_bin_grid_h * roi_bin_grid_w)
pre_size = count * self.pooled_width * self.pooled_height
bilinear_pos, bilinear_w = self.pre_calc(x_i, roi_xmin, roi_ymin,
int(roi_bin_grid_h),
int(roi_bin_grid_w),
bin_size_h, bin_size_w)
for ch in range(self.channels):
align_per_bin = (bilinear_pos[ch] * bilinear_w).sum(axis=-1)
output_val = align_per_bin.mean(axis=-1)
self.out_data[i, ch, :, :] = output_val
def make_rois(self):
rois = []
self.rois_lod = [[]]
for bno in range(self.batch_size):
self.rois_lod[0].append(bno + 1)
for i in range(bno + 1):
                # np.random.random_integers is deprecated; randint's upper
                # bound is exclusive, hence the +1.
                x1 = np.random.randint(
                    0, self.width // self.spatial_scale - self.pooled_width + 1)
                y1 = np.random.randint(
                    0,
                    self.height // self.spatial_scale - self.pooled_height + 1)
                x2 = np.random.randint(x1 + self.pooled_width,
                                       self.width // self.spatial_scale + 1)
                y2 = np.random.randint(
                    y1 + self.pooled_height,
                    self.height // self.spatial_scale + 1)
roi = [bno, x1, y1, x2, y2]
rois.append(roi)
self.rois_num = len(rois)
self.rois = np.array(rois).astype("float64")
def setUp(self):
self.op_type = "roi_align"
self.set_data()
def test_check_output(self):
if paddle.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_output_with_place(place)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import sys
sys.path.append("..")
import unittest
import numpy as np
import paddle.fluid.core as core
from op_test import OpTest, skip_check_grad_ci
import paddle
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
@skip_check_grad_ci(reason="There is no grad kernel for stack_xpu op.")
class TestStackOpBase(OpTest):
def initDefaultParameters(self):
self.num_inputs = 4
self.input_dim = (5, 6, 7)
self.axis = 0
self.dtype = 'float64'
def initParameters(self):
pass
def get_x_names(self):
x_names = []
for i in range(self.num_inputs):
x_names.append('x{}'.format(i))
return x_names
def setUp(self):
self.initDefaultParameters()
self.initParameters()
self.op_type = 'stack'
self.x = []
for i in range(self.num_inputs):
self.x.append(
np.random.random(size=self.input_dim).astype(self.dtype))
tmp = []
x_names = self.get_x_names()
for i in range(self.num_inputs):
tmp.append((x_names[i], self.x[i]))
self.inputs = {'X': tmp}
self.outputs = {'Y': np.stack(self.x, axis=self.axis)}
self.attrs = {'axis': self.axis}
def test_check_output(self):
if paddle.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_output_with_place(place)
class TestStackOp1(TestStackOpBase):
def initParameters(self):
self.num_inputs = 16
class TestStackOp2(TestStackOpBase):
def initParameters(self):
self.num_inputs = 20
class TestStackOp3(TestStackOpBase):
def initParameters(self):
self.axis = -1
class TestStackOp4(TestStackOpBase):
def initParameters(self):
self.axis = -4
class TestStackOp5(TestStackOpBase):
def initParameters(self):
self.axis = 1
class TestStackOp6(TestStackOpBase):
def initParameters(self):
self.axis = 3
if __name__ == '__main__':
unittest.main()