/**
 * \file imperative/src/impl/transformations/scalar.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/imperative/transformations/scalar.h"

#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/utils/stats.h"

namespace mgb {
namespace imperative {

namespace {

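// This transformation emulates rank-0 (scalar) tensors on top of shape-{1}
// tensors: a scalar is stored as a {1} tensor wrapped in a ScalarValue. Ops
// run on the unwrapped tensors, and a per-op ScalarRule decides which outputs
// must be re-wrapped as scalars. For example, adding two scalars applies
// Elemwise to two {1} tensors and re-wraps the {1} result.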
using ScalarRule = ValueRefList (*)(
        const OpDef&, Span<ValueRef>, Span<bool>, const Type<ScalarValue>&);
static std::unordered_map<Typeinfo*, ScalarRule> scalar_rules;

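// Builds a const Int32 tensor holding {1}, used as the concrete target shape
// whenever a scalar result has to be materialized (see the Reduce, Reshape
// and Broadcast rules below).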
ValueRef make_scalar_shape(CompNode device) {
    HostTensorND scalar_shape(device, {1}, dtype::Int32());
    scalar_shape.ptr<dt_int32>()[0] = 1;
    return imperative::apply(
            CreateTensor(CreateTensor::Const, device, scalar_shape.layout()),
            HostStorage::make(scalar_shape.storage()))[0];
}

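// Returns true iff `shape` is a shape tensor describing a scalar, i.e. the
// tensor of axis lengths is itself empty (ValueShape{0}).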
bool is_scalar_shape(ValueRef shape) {
    // note: querying the shape here may be expensive (known performance concern)
    auto shape_of_shape = shape.shape();
    if (!shape_of_shape) {
        // assume not scalar
        return false;
    }
    return *shape_of_shape == ValueShape{0};
}

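// Registers `rule` under T's runtime typeinfo; the wrapper lambda downcasts
// the generic OpDef back to T before dispatching.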
template <
        typename T,
        ValueRefList (*rule)(
                const T&, Span<ValueRef>, Span<bool>, const Type<ScalarValue>&)>
void register_scalar_rule() {
    scalar_rules[T::typeinfo()] = [](const OpDef& def, Span<ValueRef> inputs,
                                     Span<bool> inputs_mask,
                                     const Type<ScalarValue>& value_type) {
        return (*rule)(def.cast_final_safe<T>(), inputs, inputs_mask, value_type);
    };
}

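// Shared rule for elementwise-like ops: the output is a scalar only if every
// input is a scalar. nr_inputs == 0 skips the arity check.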
template <typename TOpDef, size_t nr_inputs>
ValueRefList elemwise_rule(
        const TOpDef& op_def, Span<ValueRef> inputs, Span<bool> inputs_mask,
        const Type<ScalarValue>& scalar_type) {
    if constexpr (nr_inputs != 0) {
        mgb_assert(inputs.size() == nr_inputs, "inputs size mismatch");
    }
    bool all_scalar = true;
    for (auto&& input_mask : inputs_mask) {
        if (!input_mask) {
            all_scalar = false;
        }
    }
    auto outputs = imperative::apply(op_def, inputs);
    if (all_scalar) {
        outputs[0] = scalar_type.make(outputs[0]);
    }
    return outputs;
}

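// RemoveAxis yields a scalar when it strips every remaining axis. A rank-1
// input needs no actual computation: its underlying {1} tensor is re-wrapped
// directly.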
ValueRefList remove_axis_rule(
        const RemoveAxis& remove_axis, Span<ValueRef> inputs, Span<bool> inputs_mask,
        const Type<ScalarValue>& scalar_type) {
    mgb_assert(!inputs_mask.item());
    bool is_scalar = inputs.item().shape()->ndim == remove_axis.axis.size();
    if (is_scalar && remove_axis.axis.size() == 1) {
        return {scalar_type.make(inputs.item())};
    }
    auto outputs = imperative::apply(remove_axis, inputs);
    if (is_scalar) {
        outputs[0] = scalar_type.make(outputs[0]);
    }
    return outputs;
}

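// Two-input Reduce carries a target shape as inputs[1]; when that shape is
// scalar, substitute the concrete {1} shape and wrap the result.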
ValueRefList reduce_rule(
        const Reduce& reduce, Span<ValueRef> inputs, Span<bool> inputs_mask,
        const Type<ScalarValue>& scalar_type) {
    if (inputs.size() == 1) {
        return imperative::apply(reduce, inputs);
    }
    mgb_assert(inputs.size() == 2);
    bool is_scalar = is_scalar_shape(inputs[1]);
    if (is_scalar) {
        CompNode device = *inputs[0].device();
        return {scalar_type.make(
                imperative::apply(reduce, inputs[0], make_scalar_shape(device))[0])};
    }
    return imperative::apply(reduce, inputs);
}

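// The scalar flag is propagated only for the listed modes, presumably because
// these keep the output shape equal to the input shape; any other collective
// mode falls back to the plain op.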
ValueRefList collective_comm_rule(
        const CollectiveComm& collective_comm, Span<ValueRef> inputs,
        Span<bool> inputs_mask, const Type<ScalarValue>& scalar_type) {
    mgb_assert(inputs.size() == 1);
    static std::unordered_set<CollectiveComm::Mode> modes = {
            CollectiveComm::Mode::ALL_REDUCE_MAX, CollectiveComm::Mode::ALL_REDUCE_MIN,
            CollectiveComm::Mode::ALL_REDUCE_SUM, CollectiveComm::Mode::BROADCAST,
            CollectiveComm::Mode::REDUCE_SUM,
    };
    if (modes.count(collective_comm.mode) == 0) {
        return imperative::apply(collective_comm, inputs);
    }
    if (inputs_mask.item()) {
        return {scalar_type.make(imperative::apply(collective_comm, inputs[0])[0])};
    } else {
        return imperative::apply(collective_comm, inputs);
    }
}

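// ParamPackSplit declares an explicit shape per output; an empty shape entry
// marks that output as a scalar.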
ValueRefList param_pack_split_rule(
        const ParamPackSplit& param_pack_split, Span<ValueRef> inputs,
        Span<bool> inputs_mask, const Type<ScalarValue>& scalar_type) {
    auto outputs = imperative::apply(param_pack_split, inputs);
    size_t nr_outputs = outputs.size();
    mgb_assert(nr_outputs == param_pack_split.shapes.size());
    for (size_t i = 0; i < nr_outputs; ++i) {
        if (param_pack_split.shapes[i].empty()) {
            outputs[i] = scalar_type.make(outputs[i]);
        }
    }
    return outputs;
}

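// Dot always produces a scalar, regardless of input scalar-ness.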
ValueRefList dot_rule(
        const Dot& dot, Span<ValueRef> inputs, Span<bool> inputs_mask,
        const Type<ScalarValue>& scalar_type) {
    return {scalar_type.make(imperative::apply(dot, inputs)[0])};
}

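// AddAxis on a scalar: inserting axis 0 merely unwraps the underlying {1}
// tensor; any remaining axes are applied via a trimmed-down AddAxis op.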
ValueRefList add_axis_rule(
        const AddAxis& add_axis, Span<ValueRef> inputs, Span<bool> inputs_mask,
        const Type<ScalarValue>& scalar_type) {
    mgb_assert(inputs.size() == 1);
    if (inputs_mask.item()) {
        mgb_assert(add_axis.axis[0] == 0);
        if (add_axis.axis.size() == 1) {
            return {inputs[0]};
        } else {
            std::vector<int32_t> axis(add_axis.axis.begin() + 1, add_axis.axis.end());
            return imperative::apply(*AddAxis::make(axis, add_axis.scope()), inputs[0]);
        }
    } else {
        return imperative::apply(add_axis, inputs);
    }
}

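// A RemoteRecv with an empty target shape would receive a scalar; the op is
// rebuilt with shape {1} so the actual transfer always carries a concrete
// shape.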
ValueRefList remote_recv_rule(
        const RemoteRecv& remote_recv, Span<ValueRef> inputs, Span<bool> inputs_mask,
        const Type<ScalarValue>& scalar_type) {
    if (remote_recv.shape.empty()) {
        std::vector<int32_t> shape = {1};
        auto remote_recv_no_scalar = RemoteRecv::make(
                remote_recv.key, remote_recv.addr, remote_recv.port,
                remote_recv.rank_from, remote_recv.cn, shape, remote_recv.dtype,
                remote_recv.backend);
        remote_recv_no_scalar->set_scope(remote_recv.scope());
        return imperative::apply(ApplyOp(*remote_recv_no_scalar), inputs);
    } else {
        return imperative::apply(remote_recv, inputs);
    }
}

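// CheckNonFinite returns one output per input (scalar-ness preserved per
// input) plus a trailing flag, which is always scalar.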
ValueRefList check_no_finite_rule(
        const CheckNonFinite& check_no_finite, Span<ValueRef> inputs,
        Span<bool> inputs_mask, const Type<ScalarValue>& scalar_type) {
    auto outputs = imperative::apply(check_no_finite, inputs);
    mgb_assert(outputs.size() == inputs.size() + 1, "output size mismatch");
    outputs.back() = scalar_type.make(outputs.back());
    for (size_t i = 0; i < inputs.size(); ++i) {
        if (inputs_mask[i]) {
            outputs[i] = scalar_type.make(outputs[i]);
        }
    }
    return outputs;
}

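// Each integer index (idx) in subtensor.items removes one axis; the result is
// a scalar when no axis survives. An unknown input shape conservatively means
// "not scalar".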
ValueRefList subtensor_rule(
        const Subtensor& subtensor, Span<ValueRef> inputs, Span<bool> inputs_mask,
        const Type<ScalarValue>& scalar_type) {
    mgb_assert(inputs.size() >= 1);
    auto input = inputs[0];
    bool is_scalar;
    mgb_assert(!inputs_mask[0], "subtensor shouldn't have scalar input");
    if (auto shape = input.shape()) {
        size_t ndim = shape->ndim;
        for (auto&& [axis, begin, end, step, idx] : subtensor.items) {
            if (idx) {
                ndim--;
            }
        }
        is_scalar = ndim == 0;
    } else {
        // assume not scalar
        is_scalar = false;
    }
    auto outputs = imperative::apply(subtensor, inputs);
    if (is_scalar) {
        outputs[0] = scalar_type.make(outputs[0]);
    }
    return outputs;
}

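// The shape of a scalar is an empty Int32 tensor (ValueShape{0}); it is
// constructed directly from empty host storage instead of running
// GetVarShape.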
ValueRefList get_var_shape_rule(
        const GetVarShape& get_var_shape, Span<ValueRef> inputs, Span<bool> inputs_mask,
        const Type<ScalarValue>& scalar_type) {
    bool all_scalar = true;
    mgb_assert(inputs.size() >= 1);
    for (auto&& input_mask : inputs_mask) {
        if (!input_mask) {
            all_scalar = false;
        }
    }
    if (all_scalar) {
        auto device = inputs[0].device();
        auto storage = HostStorage::make(*device);
        // storage->ensure_size(1);
        return imperative::apply(
                CreateTensor(
                        CreateTensor::Const, *device, dtype::Int32(), ValueShape{0}),
                storage);
    } else {
        return imperative::apply(get_var_shape, inputs);
    }
}

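// Reshape whose target shape is scalar: reshape to the concrete {1} shape,
// then wrap the result.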
ValueRefList reshape_rule(
        const Reshape& reshape, Span<ValueRef> inputs, Span<bool> inputs_mask,
        const Type<ScalarValue>& scalar_type) {
    mgb_assert(inputs.size() == 1 || inputs.size() == 2);
    size_t nr_inp = inputs.size();
    bool is_scalar = (nr_inp == 2 && is_scalar_shape(inputs[1])) ||
                     (nr_inp == 1 && reshape.shape.size() == 0);
    if (is_scalar) {
        return {scalar_type.make(imperative::apply(
                reshape, inputs[0], make_scalar_shape(*inputs[0].device()))[0])};
    } else {
        return imperative::apply(reshape, inputs);
    }
}

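// Mirrors reshape_rule: broadcast to {1} and wrap when the target shape is
// scalar.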
ValueRefList broadcast_rule(
        const Broadcast& broadcast, Span<ValueRef> inputs, Span<bool> inputs_mask,
        const Type<ScalarValue>& scalar_type) {
    mgb_assert(inputs.size() == 1 || inputs.size() == 2);
    size_t nr_inp = inputs.size();
    bool is_scalar = (nr_inp == 2 && is_scalar_shape(inputs[1])) ||
                     (nr_inp == 1 && broadcast.shape.size() == 0);
    if (is_scalar) {
        return {scalar_type.make(imperative::apply(
                broadcast, inputs[0], make_scalar_shape(*inputs[0].device()))[0])};
    } else {
        return imperative::apply(broadcast, inputs);
    }
}

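// Heuristic for subgraph-like ops (see TODO below): if every input is a
// scalar, assume every output is a scalar as well.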
template <typename T>
ValueRefList subgraph_op_rule(
        const T& op, Span<ValueRef> inputs, Span<bool> inputs_mask,
        const Type<ScalarValue>& scalar_type) {
    // TODO: add flag instead of assume
    bool all_scalar = true;
    for (auto&& input_mask : inputs_mask) {
        if (!input_mask) {
            all_scalar = false;
        }
    }
    auto outputs = imperative::apply(op, inputs);
    if (all_scalar) {
        for (auto& output : outputs) {
            output = scalar_type.make(output);
        }
    }
    return outputs;
}

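// Static registry: populates scalar_rules once during static initialization.
// A new op is supported by writing a rule with the ScalarRule signature and
// adding, e.g., register_scalar_rule<MyOp, my_op_rule>() here (MyOp and
// my_op_rule are hypothetical names for illustration).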
struct ScalarRuleRegistry {
    ScalarRuleRegistry() {
        register_scalar_rule<Elemwise, elemwise_rule<Elemwise, 0>>();
        register_scalar_rule<RemoveAxis, remove_axis_rule>();
        register_scalar_rule<Reduce, reduce_rule>();
        register_scalar_rule<TypeCvt, elemwise_rule<TypeCvt, 1>>();
        register_scalar_rule<CollectiveComm, collective_comm_rule>();
        register_scalar_rule<ParamPackSplit, param_pack_split_rule>();
        register_scalar_rule<Dot, dot_rule>();
        register_scalar_rule<AddAxis, add_axis_rule>();
        register_scalar_rule<RemoteRecv, remote_recv_rule>();
        register_scalar_rule<CheckNonFinite, check_no_finite_rule>();
        register_scalar_rule<Subtensor, subtensor_rule>();
        register_scalar_rule<GetVarShape, get_var_shape_rule>();
        register_scalar_rule<FastpathCopy, elemwise_rule<FastpathCopy, 1>>();
        register_scalar_rule<Reshape, reshape_rule>();
        register_scalar_rule<Broadcast, broadcast_rule>();
        register_scalar_rule<Copy, elemwise_rule<Copy, 1>>();
        register_scalar_rule<InplaceAdd, elemwise_rule<InplaceAdd, 4>>();
        register_scalar_rule<SubgraphOp, subgraph_op_rule<SubgraphOp>>();
        register_scalar_rule<CompiledOp, subgraph_op_rule<CompiledOp>>();
    }
} _;
}  // namespace

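// Fast path for GetAttr. For a scalar input: Shape returns a cached empty
// ShapeValue, while Value/Data strip the {1} dimension from the underlying
// host/device tensor so callers observe a true rank-0 value.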
ValueRefList ScalarTransformation::apply_get_attr(
        const GetAttr& get_attr, Span<ValueRef> inputs) {
    auto&& input = inputs.item();
    bool is_scalar = input.is(m_value_type);
    if (!is_scalar) {
        return imperative::apply(get_attr, input);
    }
    auto unwrapped_input = input.cast(m_value_type).value();
    if (get_attr.attr() == GetAttr::Shape) {
        if (!m_empty_shape) {
            m_empty_shape = ShapeValue::make();
        }
        return {m_empty_shape};
    } else {
        auto outputs = imperative::apply(get_attr, unwrapped_input);
        auto& output = outputs[0];
        switch (get_attr.attr()) {
            case GetAttr::Value: {
                auto& hv = output.cast<HostValue>();
                mgb_assert(
                        hv.shape() == ValueShape({1}),
                        "underlying value should have shape {1}, got %s",
                        hv.shape().to_string().c_str());
                output = HostValue::make(hv.dtype(), ValueShape(), hv.storage());
                break;
            }
            case GetAttr::Data: {
                auto& dv = output.cast<DeviceValue>();
                mgb_assert(
                        dv.shape() == ValueShape({1}),
                        "underlying value should have shape {1}, got %s",
                        dv.shape().to_string().c_str());
                output = DeviceValue::make(dv.dtype(), ValueShape(), dv.storage());
                break;
            }
            default:
                break;
        }
        return outputs;
    }
}

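// Generic dispatch: unwrap ScalarValue inputs, record which were scalars in
// inputs_mask, then route to a registered per-op rule; CreateTensor, IsScalar
// and IdentityLike operators get dedicated handling, and everything else
// falls back to the plain op on the unwrapped inputs.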
ValueRefList ScalarTransformation::apply_transformation(
        const Operator& op, Span<ValueRef> inputs) {
    if (auto* get_attr = op.as<GetAttr>()) {
        // fastpath for GetAttr
        return apply_get_attr(*get_attr, inputs);
    } else if (auto* apply_op = op.as<ApplyOp>()) {
        if (apply_op->op().same_type<FastpathCopy>()) {
            return inputs[0];
        }
    }
    size_t nr_inputs = inputs.size();
    ValueRefList unwrapped_inputs(nr_inputs);
    SmallVector<bool> inputs_mask(nr_inputs);
    for (size_t i = 0; i < inputs.size(); ++i) {
        if (auto&& scalar_value = inputs[i].as_ref(m_value_type)) {
            unwrapped_inputs[i] = scalar_value->value();
            inputs_mask[i] = true;
        } else {
            unwrapped_inputs[i] = inputs[i];
            inputs_mask[i] = false;
        }
    }
    auto fallback = [&] { return imperative::apply(op, unwrapped_inputs); };
    if (auto apply_op = op.as<ApplyOp>()) {
        auto iter = scalar_rules.find(apply_op->op().dyn_typeinfo());
        if (iter != scalar_rules.end()) {
            return iter->second(
                    apply_op->op(), unwrapped_inputs, inputs_mask, m_value_type);
        } else {
            // TODO: repeat op
            return fallback();
        }
    } else if (auto* create_tensor = op.as<CreateTensor>()) {
        if (create_tensor->shape().is_scalar()) {
            ValueShape scalar_shape = {1};
            CreateTensor scalar_op(
                    create_tensor->kind(), create_tensor->device(),
                    create_tensor->dtype(), scalar_shape);
            return {m_value_type.make(imperative::apply(scalar_op, inputs)[0])};
        } else {
            return imperative::apply(op, inputs);
        }
    } else if (op.as<IsScalar>()) {
        mgb_assert(nr_inputs == 1);
        return {BoolValue::make(inputs_mask[0])};
    } else if (op.is<Operator::IdentityLike>()) {
        mgb_assert(nr_inputs == 1);
        bool is_scalar = inputs_mask[0];
        auto outputs = fallback();
        if (is_scalar) {
            outputs[0] = m_value_type.make(outputs[0]);
        }
        return outputs;
    } else {
        return fallback();
    }
}

}  // namespace imperative
}  // namespace mgb