/**
 * \file imperative/src/impl/transformations/trace.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
// NOTE(review): the \file tag above names trace.cpp, but this translation unit
// implements transformations/scalar.h -- confirm and fix the path if it was
// copy-pasted.

#include "megbrain/imperative/transformations/scalar.h"

#include <cstdint>
#include <functional>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "megbrain/imperative/ops/autogen.h"

namespace mgb {
namespace imperative {

namespace {

/**
 * Per-op handler: receives the concrete OpDef and the (possibly
 * ScalarValue-wrapped) inputs and returns the outputs, wrapping those that are
 * logically scalar in ScalarValue.
 *
 * Throughout this file a logical scalar is carried as a ScalarValue around an
 * underlying value; the CreateTensor and GetAttr branches of
 * apply_transformation() show that the underlying value has shape {1}.
 */
using ScalarRule = std::function<std::vector<ValueRef>(const OpDef&, Span<ValueRef>)>;

//! Rule table keyed by the dynamic typeinfo of the op.
static std::unordered_map<Typeinfo*, ScalarRule> scalar_rules;

//! Return the value wrapped by a ScalarValue, or \p input itself if it is not one.
ValueRef unwrap_input(ValueRef input) {
    if (auto scalar_input = input.as_ref<ScalarValue>()) {
        return scalar_input->value();
    } else {
        return input;
    }
}

//! Apply unwrap_input() to every element of \p inputs, preserving order.
std::vector<ValueRef> unwrap_inputs(Span<ValueRef> inputs) {
    std::vector<ValueRef> unwrapped_inputs;
    unwrapped_inputs.reserve(inputs.size());
    for (auto&& input : inputs) {
        unwrapped_inputs.push_back(unwrap_input(input));
    }
    return unwrapped_inputs;
}

/**
 * Create a constant Int32 tensor on \p device that holds the single element 1,
 * i.e. the shape operand describing a tensor of shape {1}.
 */
ValueRef make_scalar_shape(CompNode device) {
    HostTensorND scalar_shape(device, {1}, dtype::Int32());
    // NOTE(review): element type chosen to match dtype::Int32() -- confirm.
    scalar_shape.ptr<dt_int32>()[0] = 1;
    return imperative::apply(
            CreateTensor(CreateTensor::Const, device, scalar_shape.layout()),
            HostStorage::make(scalar_shape.storage()))[0];
}

/**
 * Tell whether \p shape (a value used as a shape operand) requests a scalar
 * result.
 *
 * \return true only when \p shape is not itself a ScalarValue and its own shape
 *         is known and equals ValueShape{0}; false in every other case,
 *         including when the shape of \p shape is unavailable.
 */
bool is_scalar_shape(ValueRef shape) {
    if (shape.is<ScalarValue>()) {
        return false;
    }
    // may have performance issue
    auto shape_of_shape = shape.shape();
    if (!shape_of_shape) {
        // assume not scalar
        return false;
    }
    return *shape_of_shape == ValueShape{0};
}

/**
 * Register \p rule as the handler for op type \p T.
 *
 * The stored closure downcasts the generic OpDef with cast_final_safe<T>()
 * before forwarding to \p rule.
 */
template <typename T>
void register_scalar_rule(std::vector<ValueRef> (*rule)(const T&, Span<ValueRef>)) {
    scalar_rules[T::typeinfo()] = [rule](const OpDef& def, Span<ValueRef> inputs) {
        return (*rule)(def.cast_final_safe<T>(), inputs);
    };
}

//! Elemwise: the output is scalar iff every input is scalar.
std::vector<ValueRef> elemwise_rule(const Elemwise& elem, Span<ValueRef> inputs) {
    bool all_scalar = true;
    for (auto&& input : inputs) {
        if (!input.is<ScalarValue>()) {
            all_scalar = false;
            break;
        }
    }
    auto output = imperative::apply(elem, unwrap_inputs(inputs))[0];
    if (all_scalar) {
        return {ScalarValue::make(output)};
    } else {
        return {output};
    }
}

//! RemoveAxis: the output is scalar when every axis of the input is removed.
std::vector<ValueRef> remove_axis_rule(
        const RemoveAxis& remove_axis, Span<ValueRef> inputs) {
    mgb_assert(inputs.size() == 1);
    mgb_assert(!inputs[0].is<ScalarValue>());
    auto output = imperative::apply(remove_axis, inputs)[0];
    bool is_scalar = inputs[0].shape()->ndim == remove_axis.axis.size();
    if (is_scalar) {
        return {ScalarValue::make(output)};
    } else {
        return {output};
    }
}

/**
 * Reduce: with a single input the op is forwarded on unwrapped inputs. With two
 * inputs the second one is the target shape; if it requests a scalar, the
 * reduction is run against the {1} shape operand and the result is wrapped.
 */
std::vector<ValueRef> reduce_rule(const Reduce& reduce, Span<ValueRef> inputs) {
    if (inputs.size() == 1) {
        return imperative::apply(reduce, unwrap_inputs(inputs));
    }
    mgb_assert(inputs.size() == 2);
    if (is_scalar_shape(inputs[1])) {
        auto unwrapped_input = unwrap_input(inputs[0]);
        CompNode device = *unwrapped_input.device();
        return {ScalarValue::make(imperative::apply(
                reduce, unwrapped_input, make_scalar_shape(device))[0])};
    }
    // The scalar case has already returned above, so the result here is never
    // wrapped.
    return {imperative::apply(reduce, unwrap_inputs(inputs))[0]};
}

//! TypeCvt: scalar-ness of the single input is carried over to the output.
std::vector<ValueRef> typecvt_rule(const TypeCvt& typecvt, Span<ValueRef> inputs) {
    mgb_assert(inputs.size() == 1);
    if (auto scalar_input = inputs[0].as_ref<ScalarValue>()) {
        return {ScalarValue::make(
                imperative::apply(typecvt, scalar_input->value())[0])};
    } else {
        return imperative::apply(typecvt, inputs);
    }
}

/**
 * CollectiveComm: for the modes listed below, a scalar input yields a scalar
 * output; for all other modes the op is forwarded with the inputs untouched.
 */
std::vector<ValueRef> collective_comm_rule(
        const CollectiveComm& collective_comm, Span<ValueRef> inputs) {
    mgb_assert(inputs.size() == 1);
    static const std::unordered_set<CollectiveComm::Mode> modes = {
            CollectiveComm::Mode::ALL_REDUCE_MAX, CollectiveComm::Mode::ALL_REDUCE_MIN,
            CollectiveComm::Mode::ALL_REDUCE_SUM, CollectiveComm::Mode::BROADCAST,
            CollectiveComm::Mode::REDUCE_SUM,
    };
    if (modes.count(collective_comm.mode) == 0) {
        return imperative::apply(collective_comm, inputs);
    }
    if (auto scalar_input = inputs[0].as_ref<ScalarValue>()) {
        return {ScalarValue::make(
                imperative::apply(collective_comm, scalar_input->value())[0])};
    } else {
        return imperative::apply(collective_comm, inputs);
    }
}

//! ParamPackSplit: output i is scalar when the i-th requested shape is empty.
std::vector<ValueRef> param_pack_split_rule(
        const ParamPackSplit& param_pack_split, Span<ValueRef> inputs) {
    auto outputs = imperative::apply(param_pack_split, unwrap_inputs(inputs));
    size_t nr_outputs = outputs.size();
    mgb_assert(nr_outputs == param_pack_split.shapes.size());
    for (size_t i = 0; i < nr_outputs; ++i) {
        if (param_pack_split.shapes[i].empty()) {
            outputs[i] = ScalarValue::make(outputs[i]);
        }
    }
    return outputs;
}

//! Dot: the result is always wrapped as a scalar.
std::vector<ValueRef> dot_rule(const Dot& dot, Span<ValueRef> inputs) {
    return {ScalarValue::make(imperative::apply(dot, unwrap_inputs(inputs))[0])};
}

/**
 * AddAxis: for a scalar input the first requested axis must be 0 and is
 * satisfied by the underlying value itself, so only the remaining axes (if
 * any) are actually added.
 */
std::vector<ValueRef> add_axis_rule(const AddAxis& add_axis, Span<ValueRef> inputs) {
    mgb_assert(inputs.size() == 1);
    if (auto scalar_input = inputs[0].as_ref<ScalarValue>()) {
        // Guard the axis[0] access below.
        mgb_assert(!add_axis.axis.empty());
        mgb_assert(add_axis.axis[0] == 0);
        if (add_axis.axis.size() == 1) {
            return {scalar_input->value()};
        } else {
            // NOTE(review): element type assumed to be int32_t -- confirm
            // against the declaration of AddAxis::axis.
            std::vector<int32_t> axis(add_axis.axis.begin() + 1, add_axis.axis.end());
            return imperative::apply(
                    ApplyOp(*AddAxis::make(axis, add_axis.scope())),
                    scalar_input->value());
        }
    } else {
        return imperative::apply(add_axis, inputs);
    }
}

/**
 * RemoteRecv: an empty requested shape is replaced by {1} on a copy of the op
 * (keeping its scope). The result is returned unwrapped in both cases.
 */
std::vector<ValueRef> remote_recv_rule(
        const RemoteRecv& remote_recv, Span<ValueRef> inputs) {
    if (remote_recv.shape.empty()) {
        // NOTE(review): element type assumed to be int32_t -- confirm against
        // the declaration of RemoteRecv::shape.
        std::vector<int32_t> shape = {1};
        auto remote_recv_no_scalar = RemoteRecv::make(
                remote_recv.key, remote_recv.addr, remote_recv.port,
                remote_recv.rank_from, remote_recv.cn, shape, remote_recv.dtype,
                remote_recv.backend);
        remote_recv_no_scalar->set_scope(remote_recv.scope());
        return imperative::apply(
                ApplyOp(*remote_recv_no_scalar), unwrap_inputs(inputs));
    } else {
        return imperative::apply(remote_recv, unwrap_inputs(inputs));
    }
}

/**
 * CheckNonFinite: produces one output per input plus one trailing output. The
 * trailing output is always scalar; output i is scalar iff input i is.
 */
std::vector<ValueRef> check_non_finite_rule(
        const CheckNonFinite& check_no_finite, Span<ValueRef> inputs) {
    auto outputs = imperative::apply(check_no_finite, unwrap_inputs(inputs));
    mgb_assert(outputs.size() == inputs.size() + 1, "output size mismatch");
    outputs.back() = ScalarValue::make(outputs.back());
    for (size_t i = 0; i < inputs.size(); ++i) {
        if (inputs[i].is<ScalarValue>()) {
            outputs[i] = ScalarValue::make(outputs[i]);
        }
    }
    return outputs;
}

/**
 * Subtensor: each item with a truthy `idx` removes one dimension; the output
 * is scalar when no dimension is left. If the input shape is unavailable the
 * output is treated as non-scalar.
 */
std::vector<ValueRef> subtensor_rule(
        const Subtensor& subtensor, Span<ValueRef> inputs) {
    mgb_assert(inputs.size() >= 1);
    auto input = inputs[0];
    mgb_assert(!input.is<ScalarValue>(), "subtensor shouldn't have scalar input");
    bool is_scalar = false;
    if (auto shape = input.shape()) {
        size_t ndim = shape->ndim;
        for (auto&& [axis, begin, end, step, idx] : subtensor.items) {
            if (idx) {
                ndim--;
            }
        }
        is_scalar = ndim == 0;
    }
    auto output = imperative::apply(subtensor, unwrap_inputs(inputs))[0];
    if (is_scalar) {
        return {ScalarValue::make(output)};
    } else {
        return {output};
    }
}

/**
 * GetVarShape: when every input is scalar, return a constant Int32 tensor of
 * shape {0} on the first input's device instead of running the op; otherwise
 * forward on unwrapped inputs.
 */
std::vector<ValueRef> get_var_shape_rule(
        const GetVarShape& get_var_shape, Span<ValueRef> inputs) {
    bool all_scalar = true;
    mgb_assert(inputs.size() >= 1);
    for (auto&& input : inputs) {
        if (!input.is<ScalarValue>()) {
            all_scalar = false;
            break;
        }
    }
    if (all_scalar) {
        auto device = inputs[0].cast<ScalarValue>().value().device();
        auto storage = HostStorage::make(*device);
        // storage->ensure_size(1);
        return imperative::apply(
                CreateTensor(
                        CreateTensor::Const, *device, dtype::Int32(), ValueShape{0}),
                storage);
    } else {
        return imperative::apply(get_var_shape, unwrap_inputs(inputs));
    }
}

//! FastpathCopy: scalar-ness of the single input is carried over to the output.
std::vector<ValueRef> fastpath_copy_rule(
        const FastpathCopy& fastpath_copy, Span<ValueRef> inputs) {
    mgb_assert(inputs.size() == 1);
    bool is_scalar = inputs[0].is<ScalarValue>();
    auto output = imperative::apply(fastpath_copy, unwrap_inputs(inputs))[0];
    if (is_scalar) {
        return {ScalarValue::make(output)};
    } else {
        return {output};
    }
}

/**
 * Reshape: if the target shape (inputs[1]) requests a scalar, reshape to {1}
 * and wrap the result; otherwise forward on unwrapped inputs.
 */
std::vector<ValueRef> reshape_rule(const Reshape& reshape, Span<ValueRef> inputs) {
    mgb_assert(inputs.size() == 2);
    bool is_scalar = is_scalar_shape(inputs[1]);
    auto unwrapped_input = unwrap_input(inputs[0]);
    if (is_scalar) {
        return {ScalarValue::make(imperative::apply(
                reshape, unwrapped_input,
                make_scalar_shape(*unwrapped_input.device()))[0])};
    } else {
        return imperative::apply(reshape, unwrap_inputs(inputs));
    }
}

/**
 * Broadcast: if the target shape (inputs[1]) requests a scalar, broadcast to
 * {1} and wrap the result; otherwise forward on unwrapped inputs.
 */
std::vector<ValueRef> broadcast_rule(
        const Broadcast& broadcast, Span<ValueRef> inputs) {
    mgb_assert(inputs.size() == 2);
    bool is_scalar = is_scalar_shape(inputs[1]);
    auto unwrapped_input = unwrap_input(inputs[0]);
    if (is_scalar) {
        return {ScalarValue::make(imperative::apply(
                broadcast, unwrapped_input,
                make_scalar_shape(*unwrapped_input.device()))[0])};
    } else {
        return imperative::apply(broadcast, unwrap_inputs(inputs));
    }
}

//! Copy: scalar-ness of the single input is carried over to the output.
std::vector<ValueRef> copy_rule(const Copy& copy, Span<ValueRef> inputs) {
    mgb_assert(inputs.size() == 1);
    bool is_scalar = inputs[0].is<ScalarValue>();
    if (is_scalar) {
        return {ScalarValue::make(imperative::apply(copy, unwrap_inputs(inputs))[0])};
    } else {
        return imperative::apply(copy, unwrap_inputs(inputs));
    }
}

//! InplaceAdd: the output is scalar iff the first of the four inputs is scalar.
std::vector<ValueRef> inplace_add_rule(
        const InplaceAdd& inplace_add, Span<ValueRef> inputs) {
    mgb_assert(inputs.size() == 4);
    bool is_scalar = inputs[0].is<ScalarValue>();
    if (is_scalar) {
        return {ScalarValue::make(
                imperative::apply(inplace_add, unwrap_inputs(inputs))[0])};
    } else {
        return imperative::apply(inplace_add, unwrap_inputs(inputs));
    }
}

/**
 * Fills scalar_rules when the file-scope instance `_` is constructed.
 * scalar_rules is defined earlier in this translation unit than `_`.
 */
struct ScalarRuleRegistry {
    ScalarRuleRegistry() {
        register_scalar_rule(elemwise_rule);
        register_scalar_rule(remove_axis_rule);
        register_scalar_rule(reduce_rule);
        register_scalar_rule(typecvt_rule);
        register_scalar_rule(collective_comm_rule);
        register_scalar_rule(param_pack_split_rule);
        register_scalar_rule(dot_rule);
        register_scalar_rule(add_axis_rule);
        register_scalar_rule(remote_recv_rule);
        register_scalar_rule(check_non_finite_rule);
        register_scalar_rule(subtensor_rule);
        register_scalar_rule(get_var_shape_rule);
        register_scalar_rule(fastpath_copy_rule);
        register_scalar_rule(reshape_rule);
        register_scalar_rule(broadcast_rule);
        register_scalar_rule(copy_rule);
        register_scalar_rule(inplace_add_rule);
    }
} _;

}  // namespace

/**
 * Dispatch \p op with scalar bookkeeping.
 *
 *  - ApplyOp: use the registered rule for the wrapped OpDef if there is one,
 *    otherwise run the op on unwrapped inputs and return the outputs unwrapped.
 *  - CreateTensor with a scalar shape: create a shape-{1} tensor instead and
 *    wrap it in ScalarValue.
 *  - GetAttr on a scalar input: report an empty shape for Shape, and re-expose
 *    Value / Data with an empty ValueShape over the same storage.
 *  - IsScalar: answer whether the single input is a ScalarValue.
 *  - identity-like operators: preserve the scalar-ness of the single input.
 *  - anything else: run on unwrapped inputs.
 */
std::vector<ValueRef> ScalarTransformation::apply_transformation(
        const Operator& op, Span<ValueRef> inputs) {
    if (auto apply_op = op.as<ApplyOp>()) {
        auto iter = scalar_rules.find(apply_op->op().dyn_typeinfo());
        if (iter != scalar_rules.end()) {
            return iter->second(apply_op->op(), inputs);
        } else {
            // TODO: repeat op
            return imperative::apply(op, unwrap_inputs(inputs));
        }
    } else if (auto* create_tensor = op.as<CreateTensor>()) {
        if (create_tensor->shape().is_scalar()) {
            ValueShape scalar_shape = {1};
            CreateTensor scalar_op(
                    create_tensor->kind(), create_tensor->device(),
                    create_tensor->dtype(), scalar_shape);
            return {ScalarValue::make(imperative::apply(scalar_op, inputs)[0])};
        } else {
            return imperative::apply(op, inputs);
        }
    } else if (auto* get_attr = op.as<GetAttr>()) {
        bool is_scalar = inputs.as_array<1>()[0].is<ScalarValue>();
        auto output = imperative::apply(op, unwrap_inputs(inputs))[0];
        if (!is_scalar) {
            return {output};
        }
        switch (get_attr->attr()) {
            case GetAttr::Shape: {
                // Scalar Shape
                return {ShapeValue::make()};
            }
            case GetAttr::Value: {
                auto& hv = output.cast<HostValue>();
                mgb_assert(
                        hv.shape() == ValueShape({1}),
                        "underlying value should has shape {1}, got %s",
                        hv.shape().to_string().c_str());
                return {HostValue::make(hv.dtype(), ValueShape(), hv.storage())};
            }
            case GetAttr::Data: {
                auto& dv = output.cast<DeviceValue>();
                mgb_assert(
                        dv.shape() == ValueShape({1}),
                        "underlying value should has shape {1}, got %s",
                        dv.shape().to_string().c_str());
                return {DeviceValue::make(dv.dtype(), ValueShape(), dv.storage())};
            }
            default:
                return {output};
        }
        // NOTE(review): operator type `IsScalar` is assumed to be the one
        // declared in transformations/scalar.h -- confirm.
    } else if (op.as<IsScalar>()) {
        return {BoolValue::make(inputs.as_array<1>()[0].is<ScalarValue>())};
        // NOTE(review): kind tag `Operator::IdentityLike` is assumed -- confirm
        // against the Operator declaration.
    } else if (op.is<Operator::IdentityLike>()) {
        bool is_scalar = inputs.as_array<1>()[0].is<ScalarValue>();
        if (is_scalar) {
            return {ScalarValue::make(imperative::apply(op, unwrap_inputs(inputs))[0])};
        } else {
            return imperative::apply(op, inputs);
        }
    } else {
        return imperative::apply(op, unwrap_inputs(inputs));
    }
}

}  // namespace imperative
}  // namespace mgb