# -*- coding: utf-8 -*-
# Copyright 2019 - present, Facebook, Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ---------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2020 Megvii Inc. All rights reserved.
# ---------------------------------------------------------------------
from collections import namedtuple

import megengine.module as M


class Conv2d(M.Conv2d):
    """
    A wrapper around :class:`megengine.module.Conv2d` that can optionally
    apply a normalization layer and an activation after the convolution.
    """

    def __init__(self, *args, norm=None, activation=None, **kwargs):
        """
        All positional and remaining keyword arguments are forwarded unchanged
        to :class:`megengine.module.Conv2d`. The extra keyword-only arguments
        below are consumed here and never reach the base class.

        Args:
            norm (M.Module, optional): a normalization layer applied to the
                convolution output. ``None`` disables normalization.
            activation (callable(Tensor) -> Tensor, optional): an activation
                applied after ``norm``. ``None`` disables the activation.
        """
        super().__init__(*args, **kwargs)
        # Assigned after the base __init__ so the module is fully set up
        # before extra attributes are attached to it.
        self.norm = norm
        self.activation = activation

    def forward(self, x):
        """
        Run the convolution, then ``norm`` (if set), then ``activation``
        (if set), and return the result.
        """
        x = super().forward(x)
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        return x


class ShapeSpec(namedtuple("_ShapeSpec", ["channels", "height", "width", "stride"])):
    """
    A simple structure that contains basic shape specification about a tensor.
    Useful for getting the modules output channels when building the graph.

    Fields (all optional, defaulting to ``None`` when unknown):
        channels, height, width, stride
    """

    def __new__(cls, channels=None, height=None, width=None, stride=None):
        # Override __new__ only to give every field a default of None.
        return super().__new__(cls, channels, height, width, stride)