From 99ea0a9c3b02e67fdb8f902c2e3df64761f94db3 Mon Sep 17 00:00:00 2001 From: zhaoying9105 Date: Tue, 28 Jun 2022 16:51:19 +0800 Subject: [PATCH] [MLU]: add roi_align and roi_align_grad kernel (#43757) --- paddle/fluid/operators/mlu/mlu_baseop.cc | 70 ++++ paddle/fluid/operators/mlu/mlu_baseop.h | 24 ++ paddle/fluid/operators/roi_align_op_mlu.cc | 299 ++++++++++++++++++ .../unittests/mlu/test_roi_align_op_mlu.py | 222 +++++++++++++ 4 files changed, 615 insertions(+) create mode 100644 paddle/fluid/operators/roi_align_op_mlu.cc create mode 100644 python/paddle/fluid/tests/unittests/mlu/test_roi_align_op_mlu.py diff --git a/paddle/fluid/operators/mlu/mlu_baseop.cc b/paddle/fluid/operators/mlu/mlu_baseop.cc index 77bc446243..1b40a4e74f 100644 --- a/paddle/fluid/operators/mlu/mlu_baseop.cc +++ b/paddle/fluid/operators/mlu/mlu_baseop.cc @@ -4557,5 +4557,75 @@ MLUCnnlDCNDesc::~MLUCnnlDCNDesc() { diff_input)); } +/* static */ void MLUCnnl::RoiAlign(const ExecutionContext& ctx, + const int pooled_height, + const int pooled_width, + const int sampling_ratio, + const float spatial_scale, + const bool aligned, + const cnnlTensorDescriptor_t input_desc, + const void* input, + const cnnlTensorDescriptor_t boxes_desc, + const void* boxes, + const cnnlTensorDescriptor_t output_desc, + void* output) { + cnnlRoiAlignDescriptor_t roialign_desc; + + PADDLE_ENFORCE_MLU_SUCCESS(cnnlCreateRoiAlignDescriptor(&roialign_desc)); + const int pool_mode = 1; // average pooling mode + PADDLE_ENFORCE_MLU_SUCCESS(cnnlSetRoiAlignDescriptor_v2(roialign_desc, + pooled_height, + pooled_width, + sampling_ratio, + spatial_scale, + pool_mode, + aligned)); + + cnnlHandle_t handle = GetHandleFromCTX(ctx); + PADDLE_ENFORCE_MLU_SUCCESS(cnnlRoiAlign_v2(handle, + roialign_desc, + input_desc, + input, + boxes_desc, + boxes, + output_desc, + output, + nullptr, + nullptr, + nullptr, + nullptr)); + PADDLE_ENFORCE_MLU_SUCCESS(cnnlDestroyRoiAlignDescriptor(roialign_desc)); +} + +/* static */ void MLUCnnl::RoiAlignBackward( + const ExecutionContext& ctx, + const int sampling_ratio, + const float spatial_scale, + const bool aligned, + const cnnlTensorDescriptor_t grads_desc, + const void* grads, + const cnnlTensorDescriptor_t boxes_desc, + const void* boxes, + const cnnlTensorDescriptor_t grads_image_desc, + void* grads_image) { + cnnlHandle_t handle = GetHandleFromCTX(ctx); + const int pool_mode = 1; // average pooling mode + PADDLE_ENFORCE_MLU_SUCCESS(cnnlRoiAlignBackward_v2(handle, + grads_desc, + grads, + boxes_desc, + boxes, + nullptr, + nullptr, + nullptr, + nullptr, + spatial_scale, + sampling_ratio, + aligned, + pool_mode, + grads_image_desc, + grads_image)); +} + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/mlu/mlu_baseop.h b/paddle/fluid/operators/mlu/mlu_baseop.h index 6882fc17f0..8dcdd33e34 100644 --- a/paddle/fluid/operators/mlu/mlu_baseop.h +++ b/paddle/fluid/operators/mlu/mlu_baseop.h @@ -1860,6 +1860,30 @@ class MLUCnnl { const void* pos_weight, const cnnlTensorDescriptor_t diff_input_desc, void* diff_input); + + static void RoiAlign(const ExecutionContext& ctx, + const int pooled_height, + const int pooled_width, + const int sampling_ratio, + const float spatial_scale, + const bool aligned, + const cnnlTensorDescriptor_t input_desc, + const void* input, + const cnnlTensorDescriptor_t boxes_desc, + const void* boxes, + const cnnlTensorDescriptor_t output_desc, + void* output); + + static void RoiAlignBackward(const ExecutionContext& ctx, + const int sampling_ratio, + const float spatial_scale, + const bool aligned, + const cnnlTensorDescriptor_t grads_desc, + const void* grads, + const cnnlTensorDescriptor_t boxes_desc, + const void* boxes, + const cnnlTensorDescriptor_t grads_image_desc, + void* grads_image); }; template diff --git a/paddle/fluid/operators/roi_align_op_mlu.cc b/paddle/fluid/operators/roi_align_op_mlu.cc new file mode 100644 index 0000000000..c6f17b56cd --- /dev/null +++ b/paddle/fluid/operators/roi_align_op_mlu.cc @@ -0,0 +1,299 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/operators/mlu/mlu_baseop.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; + +template +class ROIAlignOpMLUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in = ctx.Input("X"); + auto* rois = ctx.Input("ROIs"); + auto* out = ctx.Output("Out"); + out->mutable_data(ctx.GetPlace()); + out->set_layout(framework::DataLayout::kNHWC); + + auto pooled_height = ctx.Attr("pooled_height"); + auto pooled_width = ctx.Attr("pooled_width"); + auto spatial_scale = ctx.Attr("spatial_scale"); + auto sampling_ratio = ctx.Attr("sampling_ratio"); + auto aligned = ctx.Attr("aligned"); + const auto& in_dims = in->dims(); + int batch_size = in_dims[0]; + int rois_num = rois->dims()[0]; + + if (rois_num == 0) return; + auto cplace = platform::CPUPlace(); + std::vector roi_batch_id_list(rois_num); + int rois_batch_size = 0; + if (ctx.HasInput("RoisNum")) { + auto* rois_num_t = ctx.Input("RoisNum"); + rois_batch_size = rois_num_t->numel(); + PADDLE_ENFORCE_EQ( + rois_batch_size, + batch_size, + platform::errors::InvalidArgument( + "The batch size of rois and the batch size of images " + " must be the same. But received the batch size of rois is %d, " + "and the batch size of images is %d", + rois_batch_size, + batch_size)); + std::vector rois_num_list(rois_batch_size); + memory::Copy(cplace, + rois_num_list.data(), + ctx.GetPlace(), + rois_num_t->data(), + sizeof(int) * rois_batch_size, + nullptr /*stream*/); + int last_idx = 0; + for (int i = 0; i < rois_batch_size; i++) { + int end_idx = last_idx + rois_num_list[i]; + for (int j = last_idx; j < end_idx; j++) { + roi_batch_id_list[j] = i; + } + last_idx = end_idx; + } + } else { + auto lod = rois->lod(); + PADDLE_ENFORCE_EQ(lod.empty(), + false, + platform::errors::InvalidArgument( + "Input(ROIs) Tensor of ROIAlignOp " + "does not contain LoD information.")); + auto rois_lod = lod.back(); + rois_batch_size = rois_lod.size() - 1; + PADDLE_ENFORCE_EQ(rois_batch_size, + batch_size, + platform::errors::InvalidArgument( + "The rois_batch_size and imgs " + "batch_size must be the same. But received " + "rois_batch_size = %d, " + "batch_size = %d", + rois_batch_size, + batch_size)); + int rois_num_with_lod = rois_lod[rois_batch_size]; + PADDLE_ENFORCE_EQ( + rois_num, + rois_num_with_lod, + platform::errors::InvalidArgument( + "The actual number of rois and the number of rois " + "provided from Input(RoIsLoD) in RoIAlign must be the same." + " But received actual number of rois is %d, and the number " + "of rois from RoIsLoD is %d", + rois_num, + rois_num_with_lod)); + for (int i = 0; i < rois_batch_size; i++) { + int start_idx = rois_lod[i]; + int end_idx = rois_lod[i + 1]; + for (int j = start_idx; j < end_idx; j++) { + roi_batch_id_list[j] = i; + } + } + } + + // only support float32 for now + Tensor rois_cpu(framework::TransToPhiDataType(VT::FP32)); + rois_cpu.Resize({rois_num, 4}); + rois_cpu.mutable_data(ctx.GetPlace()); + auto& dev_ctx = ctx.template device_context(); + framework::TensorCopy(*rois, cplace, dev_ctx, &rois_cpu); + dev_ctx.Wait(); + T* rois_cpu_ptr = rois_cpu.mutable_data(platform::CPUPlace()); + + // boxes; [batch_idx, x1, y1, x2, y2] + Tensor boxes_cpu(framework::TransToPhiDataType(VT::FP32)); + Tensor boxes_mlu(framework::TransToPhiDataType(VT::FP32)); + boxes_cpu.Resize({rois_num, 5}); + boxes_mlu.Resize({rois_num, 5}); + T* boxes_cpu_ptr = boxes_cpu.mutable_data(platform::CPUPlace()); + boxes_mlu.mutable_data(ctx.GetPlace()); + for (int i = 0; i < rois_num; ++i) { + boxes_cpu_ptr[i * 5 + 0] = static_cast(roi_batch_id_list[i]); + boxes_cpu_ptr[i * 5 + 1] = rois_cpu_ptr[i * 4 + 0]; + boxes_cpu_ptr[i * 5 + 2] = rois_cpu_ptr[i * 4 + 1]; + boxes_cpu_ptr[i * 5 + 3] = rois_cpu_ptr[i * 4 + 2]; + boxes_cpu_ptr[i * 5 + 4] = rois_cpu_ptr[i * 4 + 3]; + } + + // copy boxes_cpu to boxes_mlu + framework::TensorCopy(boxes_cpu, ctx.GetPlace(), dev_ctx, &boxes_mlu); + dev_ctx.Wait(); + + const std::vector perm_to_nhwc = {0, 2, 3, 1}; + const std::vector perm_to_nchw = {0, 3, 1, 2}; + Tensor input_nhwc(in->type()); + Tensor output_nhwc(out->type()); + TransposeFromMLUTensor( + ctx, perm_to_nhwc, in, &input_nhwc, true /*need_reshape_or_alloc*/); + auto output_dims = out->dims(); + output_nhwc.mutable_data( + {output_dims[0], output_dims[2], output_dims[3], output_dims[1]}, + ctx.GetPlace()); + + MLUCnnlTensorDesc input_desc( + input_nhwc, CNNL_LAYOUT_NHWC, ToCnnlDataType(input_nhwc.dtype())); + MLUCnnlTensorDesc boxes_desc(boxes_mlu); + MLUCnnlTensorDesc out_desc( + output_nhwc, CNNL_LAYOUT_NHWC, ToCnnlDataType(output_nhwc.dtype())); + MLUCnnl::RoiAlign(ctx, + pooled_height, + pooled_width, + sampling_ratio, + spatial_scale, + aligned, + input_desc.get(), + GetBasePtr(&input_nhwc), + boxes_desc.get(), + GetBasePtr(&boxes_mlu), + out_desc.get(), + GetBasePtr(&output_nhwc)); + TransposeFromMLUTensor( + ctx, perm_to_nchw, &output_nhwc, out, false /*need_reshape_or_alloc*/); + }; +}; + +template +class ROIAlignGradOpMLUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* rois = ctx.Input("ROIs"); + auto* out_grad = ctx.Input(framework::GradVarName("Out")); + auto* in_grad = ctx.Output(framework::GradVarName("X")); + + auto spatial_scale = ctx.Attr("spatial_scale"); + auto sampling_ratio = ctx.Attr("sampling_ratio"); + auto aligned = ctx.Attr("aligned"); + int rois_num = rois->dims()[0]; + + if (!in_grad) { + return; + } + in_grad->mutable_data(ctx.GetPlace()); + + std::vector roi_batch_id_list(rois_num); + auto cplace = platform::CPUPlace(); + int rois_batch_size = 0; + if (ctx.HasInput("RoisNum")) { + auto* rois_num_t = ctx.Input("RoisNum"); + rois_batch_size = rois_num_t->numel(); + std::vector rois_num_list(rois_batch_size); + memory::Copy(cplace, + rois_num_list.data(), + ctx.GetPlace(), + rois_num_t->data(), + sizeof(int) * rois_batch_size, + nullptr /*stream*/); + int last_idx = 0; + for (int i = 0; i < rois_batch_size; i++) { + int end_idx = last_idx + rois_num_list[i]; + for (int j = last_idx; j < end_idx; j++) { + roi_batch_id_list[j] = i; + } + last_idx = end_idx; + } + } else { + auto rois_lod = rois->lod().back(); + rois_batch_size = rois_lod.size() - 1; + for (int i = 0; i < rois_batch_size; i++) { + int start_idx = rois_lod[i]; + int end_idx = rois_lod[i + 1]; + for (int j = start_idx; j < end_idx; j++) { + roi_batch_id_list[j] = i; + } + } + } + + Tensor rois_cpu(framework::TransToPhiDataType(VT::FP32)); + rois_cpu.Resize({rois_num, 4}); + rois_cpu.mutable_data(ctx.GetPlace()); + auto& dev_ctx = ctx.template device_context(); + framework::TensorCopy(*rois, cplace, dev_ctx, &rois_cpu); + dev_ctx.Wait(); + T* rois_cpu_ptr = rois_cpu.mutable_data(platform::CPUPlace()); + + // boxes; [batch_idx, x1, y1, x2, y2] + Tensor boxes_cpu(framework::TransToPhiDataType(VT::FP32)); + Tensor boxes_mlu(framework::TransToPhiDataType(VT::FP32)); + boxes_cpu.Resize({rois_num, 5}); + boxes_mlu.Resize({rois_num, 5}); + T* boxes_cpu_ptr = boxes_cpu.mutable_data(platform::CPUPlace()); + boxes_mlu.mutable_data(ctx.GetPlace()); + for (int i = 0; i < rois_num; ++i) { + boxes_cpu_ptr[i * 5 + 0] = static_cast(roi_batch_id_list[i]); + boxes_cpu_ptr[i * 5 + 1] = rois_cpu_ptr[i * 4 + 0]; + boxes_cpu_ptr[i * 5 + 2] = rois_cpu_ptr[i * 4 + 1]; + boxes_cpu_ptr[i * 5 + 3] = rois_cpu_ptr[i * 4 + 2]; + boxes_cpu_ptr[i * 5 + 4] = rois_cpu_ptr[i * 4 + 3]; + } + + // copy boxes_cpu to boxes_mlu + framework::TensorCopy(boxes_cpu, ctx.GetPlace(), dev_ctx, &boxes_mlu); + dev_ctx.Wait(); + + const std::vector perm_to_nhwc = {0, 2, 3, 1}; + const std::vector perm_to_nchw = {0, 3, 1, 2}; + Tensor grads_nhwc(out_grad->type()); + Tensor grads_image_nhwc(in_grad->type()); + TransposeFromMLUTensor(ctx, + perm_to_nhwc, + out_grad, + &grads_nhwc, + true /*need_reshape_or_alloc*/); + auto grads_image_dims = in_grad->dims(); + grads_image_nhwc.mutable_data({grads_image_dims[0], + grads_image_dims[2], + grads_image_dims[3], + grads_image_dims[1]}, + ctx.GetPlace()); + + MLUCnnlTensorDesc grads_desc( + grads_nhwc, CNNL_LAYOUT_NHWC, ToCnnlDataType(grads_nhwc.dtype())); + MLUCnnlTensorDesc boxes_desc(boxes_mlu); + MLUCnnlTensorDesc grads_image_desc( + grads_image_nhwc, + CNNL_LAYOUT_NHWC, + ToCnnlDataType(grads_image_nhwc.dtype())); + MLUCnnl::RoiAlignBackward(ctx, + sampling_ratio, + spatial_scale, + aligned, + grads_desc.get(), + GetBasePtr(&grads_nhwc), + boxes_desc.get(), + GetBasePtr(&boxes_mlu), + grads_image_desc.get(), + GetBasePtr(&grads_image_nhwc)); + TransposeFromMLUTensor(ctx, + perm_to_nchw, + &grads_image_nhwc, + in_grad, + false /*need_reshape_or_alloc*/); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_MLU_KERNEL(roi_align, ops::ROIAlignOpMLUKernel); + +REGISTER_OP_MLU_KERNEL(roi_align_grad, ops::ROIAlignGradOpMLUKernel); diff --git a/python/paddle/fluid/tests/unittests/mlu/test_roi_align_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_roi_align_op_mlu.py new file mode 100644 index 0000000000..daf7e5a853 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mlu/test_roi_align_op_mlu.py @@ -0,0 +1,222 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import math +import sys + +sys.path.append("..") +from op_test import OpTest +import paddle + +paddle.enable_static() +np.random.seed(1243) + + +class TestROIAlignMLUOp(OpTest): + + def set_data(self): + self.init_test_case() + self.make_rois() + self.calc_roi_align() + + seq_len = self.rois_lod[0] + + self.inputs = { + 'X': self.x, + 'ROIs': self.rois[:, 1:5], + 'RoisNum': np.asarray(seq_len).astype('int32') + } + # print("self.inputs: ",self.inputs) + + self.attrs = { + 'spatial_scale': self.spatial_scale, + 'pooled_height': self.pooled_height, + 'pooled_width': self.pooled_width, + 'sampling_ratio': self.sampling_ratio, + 'aligned': self.aligned + } + + self.outputs = {'Out': self.out_data} + + def init_test_case(self): + self.batch_size = 3 + self.channels = 3 + self.height = 8 + self.width = 6 + + # n, c, h, w + self.x_dim = (self.batch_size, self.channels, self.height, self.width) + + self.spatial_scale = 1.0 / 2.0 + self.pooled_height = 2 + self.pooled_width = 2 + self.sampling_ratio = 2 + self.aligned = False + + self.x = np.random.random(self.x_dim).astype('float32') + + def pre_calc(self, x_i, roi_xmin, roi_ymin, roi_bin_grid_h, roi_bin_grid_w, + bin_size_h, bin_size_w): + count = roi_bin_grid_h * roi_bin_grid_w + bilinear_pos = np.zeros( + [self.channels, self.pooled_height, self.pooled_width, count, 4], + np.float32) + bilinear_w = np.zeros([self.pooled_height, self.pooled_width, count, 4], + np.float32) + for ph in range(self.pooled_width): + for pw in range(self.pooled_height): + c = 0 + for iy in range(roi_bin_grid_h): + y = roi_ymin + ph * bin_size_h + (iy + 0.5) * \ + bin_size_h / roi_bin_grid_h + for ix in range(roi_bin_grid_w): + x = roi_xmin + pw * bin_size_w + (ix + 0.5) * \ + bin_size_w / roi_bin_grid_w + if y < -1.0 or y > self.height or \ + x < -1.0 or x > self.width: + continue + if y <= 0: + y = 0 + if x <= 0: + x = 0 + y_low = int(y) + x_low = int(x) + if y_low >= self.height - 1: + y = y_high = y_low = self.height - 1 + else: + y_high = y_low + 1 + if x_low >= self.width - 1: + x = x_high = x_low = self.width - 1 + else: + x_high = x_low + 1 + ly = y - y_low + lx = x - x_low + hy = 1 - ly + hx = 1 - lx + for ch in range(self.channels): + bilinear_pos[ch, ph, pw, c, 0] = x_i[ch, y_low, + x_low] + bilinear_pos[ch, ph, pw, c, 1] = x_i[ch, y_low, + x_high] + bilinear_pos[ch, ph, pw, c, 2] = x_i[ch, y_high, + x_low] + bilinear_pos[ch, ph, pw, c, 3] = x_i[ch, y_high, + x_high] + bilinear_w[ph, pw, c, 0] = hy * hx + bilinear_w[ph, pw, c, 1] = hy * lx + bilinear_w[ph, pw, c, 2] = ly * hx + bilinear_w[ph, pw, c, 3] = ly * lx + c = c + 1 + return bilinear_pos, bilinear_w + + def calc_roi_align(self): + self.out_data = np.zeros( + (self.rois_num, self.channels, self.pooled_height, + self.pooled_width)).astype('float32') + + offset = 0.5 if self.aligned else 0. + for i in range(self.rois_num): + roi = self.rois[i] + roi_batch_id = int(roi[0]) + x_i = self.x[roi_batch_id] + roi_xmin = roi[1] * self.spatial_scale - offset + roi_ymin = roi[2] * self.spatial_scale - offset + roi_xmax = roi[3] * self.spatial_scale - offset + roi_ymax = roi[4] * self.spatial_scale - offset + + roi_width = roi_xmax - roi_xmin + roi_height = roi_ymax - roi_ymin + if not self.aligned: + roi_width = max(roi_width, 1) + roi_height = max(roi_height, 1) + + bin_size_h = float(roi_height) / float(self.pooled_height) + bin_size_w = float(roi_width) / float(self.pooled_width) + roi_bin_grid_h = self.sampling_ratio if self.sampling_ratio > 0 else \ + math.ceil(roi_height / self.pooled_height) + roi_bin_grid_w = self.sampling_ratio if self.sampling_ratio > 0 else \ + math.ceil(roi_width / self.pooled_width) + count = max(int(roi_bin_grid_h * roi_bin_grid_w), 1) + pre_size = count * self.pooled_width * self.pooled_height + bilinear_pos, bilinear_w = self.pre_calc(x_i, roi_xmin, roi_ymin, + int(roi_bin_grid_h), + int(roi_bin_grid_w), + bin_size_h, bin_size_w) + for ch in range(self.channels): + align_per_bin = (bilinear_pos[ch] * bilinear_w).sum(axis=-1) + output_val = align_per_bin.mean(axis=-1) + self.out_data[i, ch, :, :] = output_val + + def make_rois(self): + rois = [] + self.rois_lod = [[]] + for bno in range(self.batch_size): + # for i in range(bno + 1): + # self.rois_lod[0].append(bno) + self.rois_lod[0].append(1) + x1 = np.random.randint( + 0, self.width // self.spatial_scale - self.pooled_width) + y1 = np.random.randint( + 0, self.height // self.spatial_scale - self.pooled_height) + + x2 = np.random.randint(x1 + self.pooled_width, + self.width // self.spatial_scale) + y2 = np.random.randint(y1 + self.pooled_height, + self.height // self.spatial_scale) + + roi = [bno, x1, y1, x2, y2] + rois.append(roi) + + self.rois_num = len(rois) + self.rois = np.array(rois).astype("float32") + + def setUp(self): + self.op_type = "roi_align" + self.__class__.use_mlu = True + self.place = paddle.MLUPlace(0) + self.set_data() + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad(self): + self.check_grad_with_place(self.place, ['X'], 'Out') + + +class TestROIAlignOpWithMinusSample(TestROIAlignMLUOp): + + def init_test_case(self): + self.batch_size = 3 + self.channels = 3 + self.height = 8 + self.width = 6 + + # n, c, h, w + self.x_dim = (self.batch_size, self.channels, self.height, self.width) + + self.spatial_scale = 1.0 / 2.0 + self.pooled_height = 2 + self.pooled_width = 2 + self.sampling_ratio = -1 + self.aligned = False + + self.x = np.random.random(self.x_dim).astype('float32') + + +if __name__ == '__main__': + unittest.main() -- GitLab