diff --git a/imperative/src/impl/ops/broadcast.cpp b/imperative/src/impl/ops/broadcast.cpp index 2895688b53ecda1fb8d03a06a4cecdff3748bd1c..b63883e2ef8ead62f8ddba7c75c186beba2755cc 100644 --- a/imperative/src/impl/ops/broadcast.cpp +++ b/imperative/src/impl/ops/broadcast.cpp @@ -12,6 +12,8 @@ #include "megbrain/imperative/ops/autogen.h" #include "megbrain/opr/tensor_manip.h" +#include "megbrain/graph/helper.h" + #include "../op_trait.h" namespace mgb { @@ -83,10 +85,46 @@ std::tuple, bool> infer_output_attrs_fallible( return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true}; } +std::tuple, SmallVector> infer_output_mem_desc( + const OpDef& def, + const SmallVector& inputs_tensors, + const SmallVector& inputs_mems) { + auto& input = inputs_tensors[0]; + TensorShape target_shape; + cg::copy_tensor_value_to_shape(target_shape, inputs_tensors[1]->get_value().proxy_to_default_cpu()); + // TODO: memory forward + // if (input->shape().eq_shape(target_shape)) { + // return {{{input->layout(), 0, input->comp_node(), StorageIdentifier::make(&inputs_mems[0])}}, {}}; + // } + return {{{{target_shape, input->dtype()}, 0, input->comp_node(), StorageIdentifier::make(0)}}, {}}; +} + +void execute( + const OpDef& def, + SmallVector inputs, + SmallVector outputs, + SmallVector workspace) { + if (outputs[0]->layout().is_empty()) { + return; + } + if (inputs[0]->shape().eq_shape(outputs[0]->shape())) { + mgb_assert(inputs[0]->layout().eq_layout(outputs[0]->layout())); + // TODO: memory forward + // mgb_assert(inputs[0]->offset() == outputs[0]->offset()); + // mgb_assert(inputs[0]->blob() == outputs[0]->blob()); + outputs[0]->dev_tensor().copy_from_fixlayout(inputs[0]->dev_tensor()); + } else { + TensorLayout input_layout = inputs[0]->layout().broadcast(outputs[0]->shape()); + outputs[0]->dev_tensor().copy_from_fixlayout(inputs[0]->dev_tensor().sub(SubTensorSpec::make_from_layout(input_layout))); + } +} + OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast) .make_from_op_node(make_from_op_node) .apply_on_var_node(apply_on_var_node) .infer_output_attrs_fallible(infer_output_attrs_fallible) + .infer_output_mem_desc(infer_output_mem_desc) + .execute(execute) .fallback(); } // broadcast diff --git a/imperative/src/impl/ops/reduce.cpp b/imperative/src/impl/ops/reduce.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b9aa65203e9454dbd1d87ab6699eb747065de5ab --- /dev/null +++ b/imperative/src/impl/ops/reduce.cpp @@ -0,0 +1,47 @@ +/** + * \file imperative/src/impl/ops/reduce.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "megbrain/imperative/ops/autogen.h" +#include "megbrain/opr/basic_arith.h" + +#include "../op_trait.h" +#include "../dnn_op_helper.h" + +namespace mgb { +namespace imperative { +namespace { +namespace reduce { +auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { + auto&& reduce = static_cast(def); + OperatorNodeConfig config{reduce.make_name()}; + if (inputs.size() > 1) { + return opr::Reduce::make(inputs[0], reduce.param(), inputs[1], config); + } else { + return opr::Reduce::make(inputs[0], reduce.param(), + (cg::VarNode*)nullptr, config); + } +} + +std::shared_ptr make_from_op_node(cg::OperatorNodeBase* node_) { + auto* node = &node_->cast_final_safe(); + return Reduce::make(node->param()); +} + +OP_TRAIT_REG(Reduce, Reduce, opr::Reduce) + .make_from_op_node(make_from_op_node) + .apply_on_var_node(apply_on_var_node) + .fallback(); +} // namespace reduce +} // namespace +} // namespace imperative +} // namespace mgb + +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/imperative/src/impl/ops/specializations.cpp b/imperative/src/impl/ops/specializations.cpp index edb76da194df8ba2ceb585cfeee8f833358b8391..ac8a50bea08382776f4e2abb08d5b08817d6cdfa 100644 --- a/imperative/src/impl/ops/specializations.cpp +++ b/imperative/src/impl/ops/specializations.cpp @@ -116,31 +116,6 @@ OP_TRAIT_REG(TopK, TopK).apply_on_var_node(apply_on_var_node).fallback(); } // namespace top_k } // namespace -namespace { -namespace reduce { -auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { - auto&& reduce = static_cast(def); - OperatorNodeConfig config{reduce.make_name()}; - if (inputs.size() > 1) { - return opr::Reduce::make(inputs[0], reduce.param(), inputs[1], config); - } else { - return opr::Reduce::make(inputs[0], reduce.param(), - (cg::VarNode*)nullptr, config); - } -} - -std::shared_ptr make_from_op_node(cg::OperatorNodeBase* node_) { - auto* node = &node_->cast_final_safe(); - return Reduce::make(node->param()); -} - -OP_TRAIT_REG(Reduce, Reduce, opr::Reduce) - .make_from_op_node(make_from_op_node) - .apply_on_var_node(apply_on_var_node) - .fallback(); -} // namespace reduce -} // namespace - namespace { namespace adaptive_pooling { auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {