#include "megbrain/graph/symbol_var.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/proxy_graph_detail.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_manip.h"
#include "megdnn/dtype.h"

#include <climits>

#include "../blob_manager_impl.h"
#include "../dnn_op_helper.h"
#include "../op_trait.h"

namespace mgb {
namespace imperative {
namespace {
namespace reduce {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& reduce = static_cast<const Reduce&>(def);
20 21 22
    auto comp_node = inputs[0]->comp_node();
    OperatorNodeConfig config{reduce.make_name(), comp_node, inputs[0]->dtype()};

23 24 25
    if (inputs.size() > 1) {
        return opr::Reduce::make(inputs[0], reduce.param(), inputs[1], config);
    }
26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49

    using Param = megdnn::param::Reduce;
    auto param = reduce.param();
    if (param.axis < 0) {
        param.axis = inputs[0]->shape().ndim + param.axis;
    }

    SymbolVar target_shape = (cg::VarNode*)nullptr;
    if (param.axis == INT_MAX) {
        DTypeScalar vi{1};
        // auto graph = ComputingGraph::make();
        auto graph = inputs[0]->owner_graph();
        target_shape = opr::ImmutableTensor::make(*graph, vi, config);
    }
    auto res = opr::Reduce::make(inputs[0], param, target_shape, config);
    if (!reduce.keepdim && param.axis != INT_MAX) {
        using Desc = opr::AxisAddRemove::AxisDesc;
        std::vector<Desc> remove_param;
        remove_param.push_back(Desc::make_remove(param.axis));
        OperatorNodeConfig remove_config{
                def.make_name(), comp_node, inputs[0]->dtype()};
        return opr::AxisAddRemove::make(res, remove_param, remove_config);
    }
    return res;
50 51 52 53
}

//! Rebuild an imperative Reduce OpDef from a graph opr::Reduce node.
//! keepdim is fixed to true: the graph-level Reduce keeps the reduced axis
//! (axis squeezing is done by a separate AxisAddRemove node, see
//! apply_on_var_node).
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    auto* node = &node_->cast_final_safe<opr::Reduce>();
    return Reduce::make(node->param(), true);
}

// TODO: using this for apply_on_physical_tensor
bool memory_forward_success(const OpDef& def, SmallVector<TensorPtr> inputs) {
59 60 61 62 63 64 65 66 67 68 69 70
    auto&& reduce = static_cast<const Reduce&>(def);
    if (reduce.mode != Reduce::Mode::SUM_SQR && inputs.size() == 2) {
        auto shape_tensor = inputs[1]->get_value();
        TensorShape shape;
        cg::copy_tensor_value_to_shape(shape, shape_tensor.proxy_to_default_cpu());
        if (shape.eq_shape(inputs[0]->shape())) {
            return true;
        }
    }
    return false;
}

//! Eagerly execute a Reduce on concrete tensors.
//! Fast paths: shape-preserving reduce forwards the input storage; the
//! two-input (explicit target shape) form falls back to the proxy graph.
//! Otherwise the megdnn kernel is invoked directly, with special handling
//! for reduce-all (axis == INT_MAX) and empty inputs.
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    // No-op reduce: share the input blob instead of copying.
    if (memory_forward_success(def, inputs)) {
        return {Tensor::make(
                inputs[0]->blob(), inputs[0]->offset(), inputs[0]->layout())};
    }

    auto size = inputs.size();
    // Explicit target-shape input: let the proxy graph handle it.
    if (size > 1) {
        return proxy_graph_detail::apply_on_physical_tensor(
                def, inputs, output_descs, validated);
    }

    auto comp_node = inputs[0]->comp_node();
    using TensorND = megdnn::TensorND;
    auto&& op_def = def.cast_final_safe<Reduce>();
    SmallVector<TensorND> inp_tensornds;
    inp_tensornds.reserve(inputs.size());
    auto src = inputs[0]->layout();

    DnnOprCaller<megdnn::Reduce> dnn_op(comp_node);
    dnn_op.op->param() = op_def.param();
    auto axis = op_def.param().axis;
    auto keepdim = op_def.keepdim;

    // Normalize a negative axis to its non-negative counterpart.
    if (axis < 0) {
        axis = inputs[0]->layout().ndim + axis;
    }

    // axis == INT_MAX means "reduce all": flatten the input to 1-d below and
    // reduce along axis 0.
    dnn_op.op->param().axis = axis == INT_MAX ? 0 : axis;

    if (axis == INT_MAX) {
        src.shape[0] = src.total_nr_elems();
        src.ndim = 1;
        src.init_contiguous_stride();
    }
    TensorLayout layout{src.dtype};
    dnn_op.op->deduce_layout(src, layout);

    if (inputs[0]->layout().is_empty()) {
        // Empty input: the kernel cannot run. Only SUM (identity 0) and
        // PRODUCT (identity 1) have a well-defined result; other modes throw.
        inputs[0]->dev_tensor().reset(inputs[0]->dev_tensor().storage(), src);

        auto mode = op_def.param().mode;

        if (!keepdim && src.ndim > 1) {
            layout.remove_axis_inplace(axis);
            layout.init_contiguous_stride();
        }
        auto out = Tensor::make(layout, comp_node);

        std::string err_msg;
        switch (mode) {
            case Reduce::Mode::SUM:
                if (!out->empty()) {
                    dev_tensor_memset(out->dev_tensor(), 0);
                }
                break;
            case Reduce::Mode::PRODUCT:
                if (!out->empty()) {
                    DnnOprCaller<megdnn::Fill> fill_op(comp_node);
                    fill_op.op->param() = 1;
                    fill_op.op->exec(out->dnn_tensor(), {});
                }
                break;
            case Reduce::Mode::MEAN:
                err_msg = "mean";
                break;
            case Reduce::Mode::MIN:
                err_msg = "min";
                break;
            case Reduce::Mode::MAX:
                err_msg = "max";
                break;
            case Reduce::Mode::SUM_SQR:
                err_msg = "sum_sqr";
                break;
            default:
                mgb_throw(MegBrainError, "bad reduce mode");
        }
        if (!err_msg.empty()) {
            mgb_throw(
                    MegBrainError, "empty input is not allowed for reduce mode: %s",
                    err_msg.c_str());
        }
        return {out};
    }

    // Non-empty input: run the kernel with the (possibly flattened) source
    // layout, then squeeze the reduced axis on the output if needed.
    auto dnn_ten = inputs[0]->dnn_tensor();
    dnn_ten.layout = src;
    inp_tensornds.push_back(dnn_ten);

    auto wk_size = dnn_op.op->get_workspace_in_bytes(src, layout);
    auto dnn_wk = dnn_op.create_workspace(wk_size);
    // Keep the kernel-visible layout (with the reduced axis of extent 1);
    // the returned tensor uses the squeezed layout when keepdim is false.
    TensorLayout ori_layout = layout;

    if (!keepdim && src.ndim > 1) {
        layout.remove_axis_inplace(axis);
        layout.init_contiguous_stride();
    }

    auto out = Tensor::make(layout, comp_node);
    auto dnn_out = out->dnn_tensor();
    dnn_out.layout = ori_layout;

    dnn_op.op->exec(inp_tensornds[0], dnn_out, dnn_wk);

    return {out};
}

//! Infer the output layout of a Reduce without executing it.
//! \return (descs, validated); validated=false when any input layout is
//!         still unknown (ndim == 0).
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    auto&& op_def = def.cast_final_safe<Reduce>();
    auto axis = op_def.param().axis;
    auto keepdim = op_def.keepdim;

    size_t size = inputs.size();
    SmallVector<LogicalTensorDesc> dests(size);

    // Any unknown input layout: report an unvalidated, shapeless output.
    for (size_t i = 0; i < size; i++) {
        if (inputs[i].layout.ndim == 0) {
            return {{{TensorLayout(inputs[0].layout.dtype), inputs[0].comp_node}},
                    false};
        }
    }
    // Two-input form: defer to the proxy graph, then refine the output shape
    // from the target-shape value when it is statically known.
    if (size > 1) {
        auto [output_descs, validated] =
                proxy_graph_detail::infer_output_attrs_fallible(def, inputs);
        if (!inputs[1].value.empty()) {
            cg::copy_tensor_value_to_shape(output_descs[0].layout, inputs[1].value);
            output_descs[0].layout.init_contiguous_stride();
        }
        return {output_descs, validated};
    }

    // Normalize a negative axis to its non-negative counterpart.
    if (axis < 0) {
        axis = inputs[0].layout.ndim + axis;
    }

    // Reduce-all (axis == INT_MAX) or 1-d input: output is shape {1}.
    if (axis == INT_MAX || inputs[0].layout.ndim == 1) {
        TensorLayout layout{inputs[0].layout.dtype};
        layout.shape[0] = 1;
        layout.ndim = 1;
        dests[0].layout = layout;
        dests[0].comp_node = inputs[0].comp_node;
    } else {
        for (size_t i = 0; i < size; ++i) {
            dests[i].comp_node = inputs[i].comp_node;
            dests[i].layout = inputs[i].layout;
            // Squeeze the reduced axis when keepdim is false, otherwise keep
            // it with extent 1.
            if (!keepdim && dests[i].layout.ndim > 1) {
                dests[i].layout.remove_axis_inplace(axis);
            } else {
                dests[i].layout.shape[axis] = 1;
            }
            dests[i].layout.init_contiguous_stride();
        }
    }

    return {dests, true};
}

//! Layout constraints for eager execution: the data input (index 0) must be
//! contiguous; any further inputs carry no constraint (empty callback).
SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
    SmallVector<VarNode::LayoutConstraintCallback> checkers(inputs.size());
    auto require_contiguous = [](const TensorLayout& layout) {
        return layout.is_contiguous();
    };
    checkers[0] = require_contiguous;
    return checkers;
}

// Register the Reduce trait implementations defined above; unimplemented
// entry points fall back to the proxy-graph defaults.
OP_TRAIT_REG(Reduce, Reduce, opr::Reduce)
        .make_from_op_node(make_from_op_node)
        .apply_on_var_node(apply_on_var_node)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .get_input_layout_constraint(get_input_layout_constraint)
        .fallback();
}  // namespace reduce
}  // namespace
}  // namespace imperative
}  // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}