/**
 * \file src/opr/impl/dnn/pooling.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/graph/grad_impl.h"
#include "megbrain/opr/search_policy/algo_chooser.h"

#include "../internal/megdnn_opr_wrapper.inl"
#include "../search_policy/workspace_need_limit_getter.inl"

using namespace mgb;
using namespace opr;

/* ==================== PoolingForward ==================== */

MGB_DYN_TYPE_OBJ_FINAL_IMPL(PoolingForward);

/*!
 * \brief construct a forward pooling operator node
 * \param i0 the input var; the node is created on its owner graph
 * \param param parameters forwarded to the underlying megdnn operator
 * \param policy execution policy stored in m_policy
 * \param config operator node config passed to the base constructor
 */
PoolingForward::PoolingForward(
        VarNode* i0, const Param& param, const ExecutionPolicy& policy,
        const OperatorNodeConfig& config)
        : Super(OperatorNodeBaseCtorParam{
                  i0->owner_graph(), config, "pooling", {i0}}) {
    init_megdnn_opr(*this, param);
    add_input({i0});
    m_policy = policy;

    intl::MegDNNOprInitPostCtor<PoolingForward>::apply(*this);
}

/*!
 * \brief insert a PoolingForward operator for \p i0 and return its output
 *
 * The input is first passed through MegDNNOprInitInputsModifier, which may
 * rewrite it in place, before the operator is inserted.
 */
SymbolVar PoolingForward::make(
        SymbolVar i0, const Param& param, const OperatorNodeConfig& config,
        const ExecutionPolicy& policy) {
    intl::MegDNNOprInitInputsModifier<PoolingForward>::apply(param, {&i0});
    return i0.insert_single_output_opr<PoolingForward>(
            i0.node(), param, policy, config);
}

/*!
 * \brief register static inference descriptors for the outputs
 *
 * All outputs but the last are declared as managed outputs; the base-of-base
 * implementation then handles them, and the workspace descriptor is
 * registered separately.
 */
void PoolingForward::init_output_static_infer_desc() {
    // NOTE(review): the last output is presumably the workspace var, which is
    // why it is excluded from the managed outputs -- confirm against Super.
    Super::set_nr_managed_outputs(this->output().size() - 1);
    Super::Super::init_output_static_infer_desc();
    init_output_static_infer_desc_workspace(
            intl::AutoAddWorkspaceNeedLimitGetter<megdnn::PoolingForward>::val);
}

/*!
 * \brief compute the workspace size by delegating to the algo chooser
 * \param input_shapes shapes of the inputs; only index 0 is read
 * \param output_shapes shapes of the outputs; only index 0 is read
 * \return the value returned by AlgoChooser::setup_algo
 *
 * Layouts are assembled from the given shapes together with the dtype and
 * format of the corresponding vars.
 */
size_t PoolingForward::get_workspace_size_bytes(
        const TensorShapeArray& input_shapes,
        const TensorShapeArray& output_shapes) const {
    // NOTE(review): the meaning of the trailing `false` is defined by
    // AlgoChooser::setup_algo and is not visible here.
    return AlgoChooser<megdnn::PoolingForward>::setup_algo(
            {TensorLayout{
                     input_shapes[0], input(0)->dtype(), input(0)->format()},
             {output_shapes[0], output(0)->dtype(), output(0)->format()}},
            megdnn_opr(), this, false);
}

#if MGB_ENABLE_GRAD
/*!
 * Gradient of PoolingForward w.r.t. its only input: a PoolingBackward built
 * from the forward input, the forward output and the output gradient, using
 * the same param. Config and execution policy are left at make()'s defaults.
 */
MGB_IMPL_OPR_GRAD(PoolingForward) {
    mgb_assert(wrt_idx == 0);
    SymbolVar grad = PoolingBackward::make(
            opr.input(0), opr.output(0), out_grad[0], opr.param());
    return grad.node();
}
#endif

/* ==================== PoolingBackward ==================== */

MGB_DYN_TYPE_OBJ_FINAL_IMPL(PoolingBackward);

/*!
 * \brief construct a backward pooling operator node
 * \param i0 first input; the node is created on its owner graph
 * \param i1 second input
 * \param i2 third input
 * \param param parameters forwarded to the underlying megdnn operator
 * \param policy execution policy stored in m_policy
 * \param config operator node config passed to the base constructor
 *
 * The gradient implementation in this file passes (forward input, forward
 * output, output gradient) as (i0, i1, i2).
 */
PoolingBackward::PoolingBackward(
        VarNode* i0, VarNode* i1, VarNode* i2, const Param& param,
        const ExecutionPolicy& policy, const OperatorNodeConfig& config)
        // NOTE(review): the meaning of the trailing `0, true` is defined by
        // the Super constructor and is not visible here.
        : Super(OperatorNodeBaseCtorParam{
                        i0->owner_graph(), config, "pooling_bwd", {i0}},
                0, true) {
    init_megdnn_opr(*this, param);
    add_input({i0, i1, i2});
    // keep the caller-supplied policy, mirroring PoolingForward; otherwise the
    // policy argument would be silently ignored
    m_policy = policy;

    intl::MegDNNOprInitPostCtor<PoolingBackward>::apply(*this);
}

/*!
 * \brief insert a PoolingBackward operator and return its output
 *
 * The three inputs are first passed through MegDNNOprInitInputsModifier,
 * which may rewrite them in place, before the operator is inserted.
 */
SymbolVar PoolingBackward::make(
        SymbolVar i0, SymbolVar i1, SymbolVar i2, const Param& param,
        const OperatorNodeConfig& config, const ExecutionPolicy& policy) {
    intl::MegDNNOprInitInputsModifier<PoolingBackward>::apply(
            param, {&i0, &i1, &i2});
    return i0.insert_single_output_opr<PoolingBackward>(
            i0.node(), i1.node(), i2.node(), param, policy, config);
}

/*!
 * \brief compute the workspace size by delegating to the algo chooser
 * \param input_shapes shapes of the inputs; indices 0..2 are read
 * \param output_shapes shapes of the outputs; only index 0 is read
 * \return the value returned by AlgoChooser::setup_algo
 */
size_t PoolingBackward::get_workspace_size_bytes(
        const TensorShapeArray& input_shapes,
        const TensorShapeArray& output_shapes) const {
    return AlgoChooser<megdnn::PoolingBackward>::setup_algo(
            {TensorLayout{
                     input_shapes[0], input(0)->dtype(), input(0)->format()},
             {input_shapes[1], input(1)->dtype(), input(1)->format()},
             {input_shapes[2], input(2)->dtype(), input(2)->format()},
             {output_shapes[0], output(0)->dtype(), output(0)->format()}},
            megdnn_opr(), this, false);
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}