From e32929dfd274c6f14ff47acab025e284a769f53a Mon Sep 17 00:00:00 2001
From: Megvii Engine Team <megengine@megvii.com>
Date: Fri, 14 Jan 2022 13:22:35 +0800
Subject: [PATCH] refactor(dispatch): implement scalar

GitOrigin-RevId: b244c2ca1ad5cb28ffcf0e320cd0440f298bea51
---
 .../src/impl/transformations/scalar.cpp       | 404 ++++++++++++++++++
 .../imperative/transformations/scalar.h       |  60 +++
 2 files changed, 464 insertions(+)
 create mode 100644 imperative/src/impl/transformations/scalar.cpp
 create mode 100644 imperative/src/include/megbrain/imperative/transformations/scalar.h

diff --git a/imperative/src/impl/transformations/scalar.cpp b/imperative/src/impl/transformations/scalar.cpp
new file mode 100644
index 000000000..891abdd7c
--- /dev/null
+++ b/imperative/src/impl/transformations/scalar.cpp
@@ -0,0 +1,404 @@
+/**
+ * \file imperative/src/impl/transformations/scalar.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#include "megbrain/imperative/transformations/scalar.h"
+
+#include "megbrain/imperative/ops/autogen.h"
+
+namespace mgb {
+namespace imperative {
+
+namespace {
+
+using ScalarRule = std::function<std::vector<ValueRef>(const OpDef&, Span<ValueRef>)>;
+static std::unordered_map<Typeinfo*, ScalarRule> scalar_rules;
+
+ValueRef unwrap_input(ValueRef input) {
+    if (auto scalar_input = input.as_ref<ScalarValue>()) {
+        return scalar_input->value();
+    } else {
+        return input;
+    }
+}
+
+std::vector<ValueRef> unwrap_inputs(Span<ValueRef> inputs) {
+    std::vector<ValueRef> unwrapped_inputs;
+    for (auto&& input : inputs) {
+        unwrapped_inputs.push_back(unwrap_input(input));
+    }
+    return unwrapped_inputs;
+}
+
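+// Below this layer a scalar is simulated by a shape-{1} tensor. This helper
+// materializes the constant Int32 tensor holding {1} that is passed as the
+// target shape so that ops such as Reduce/Reshape/Broadcast yield a
+// shape-{1} result.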
+ValueRef make_scalar_shape(CompNode device) {
+    HostTensorND scalar_shape(device, {1}, dtype::Int32());
+    scalar_shape.ptr<dt_int32>()[0] = 1;
+    return imperative::apply(
+            CreateTensor(CreateTensor::Const, device, scalar_shape.layout()),
+            HostStorage::make(scalar_shape.storage()))[0];
+}
+
+bool is_scalar_shape(ValueRef shape) {
+    if (shape.is<ScalarValue>()) {
+        return false;
+    }
+    auto shape_of_shape = shape.shape();
+    if (!shape_of_shape) {
+        // assume not scalar
+        return false;
+    }
+    return *shape_of_shape == ValueShape{0};
+}
+
+template <typename T>
+void register_scalar_rule(std::vector<ValueRef> (*rule)(const T&, Span<ValueRef>)) {
+    scalar_rules[T::typeinfo()] = [rule](const OpDef& def, Span<ValueRef> inputs) {
+        return (*rule)(def.cast_final_safe<T>(), inputs);
+    };
+}
+
+std::vector<ValueRef> elemwise_rule(const Elemwise& elem, Span<ValueRef> inputs) {
+    bool all_scalar = true;
+    for (auto&& input : inputs) {
+        if (!input.is<ScalarValue>()) {
+            all_scalar = false;
+            break;
+        }
+    }
+    auto output = imperative::apply(elem, unwrap_inputs(inputs))[0];
+    if (all_scalar) {
+        return {ScalarValue::make(output)};
+    } else {
+        return {output};
+    }
+}
+
+std::vector<ValueRef> remove_axis_rule(
+        const RemoveAxis& remove_axis, Span<ValueRef> inputs) {
+    mgb_assert(inputs.size() == 1);
+    mgb_assert(!inputs[0].is<ScalarValue>());
+    auto output = imperative::apply(remove_axis, inputs)[0];
+    bool is_scalar = inputs[0].shape()->ndim == remove_axis.axis.size();
+    if (is_scalar) {
+        return {ScalarValue::make(output)};
+    } else {
+        return {output};
+    }
+}
+
+std::vector<ValueRef> reduce_rule(const Reduce& reduce, Span<ValueRef> inputs) {
+    if (inputs.size() == 1) {
+        return imperative::apply(reduce, unwrap_inputs(inputs));
+    }
+    mgb_assert(inputs.size() == 2);
+    bool is_scalar = is_scalar_shape(inputs[1]);
+    if (is_scalar) {
+        auto unwrapped_input = unwrap_input(inputs[0]);
+        CompNode device = *unwrapped_input.device();
+        return {ScalarValue::make(imperative::apply(
+                reduce, unwrapped_input, make_scalar_shape(device))[0])};
+    }
+    return imperative::apply(reduce, unwrap_inputs(inputs));
+}
+
+std::vector<ValueRef> typecvt_rule(const TypeCvt& typecvt, Span<ValueRef> inputs) {
+    mgb_assert(inputs.size() == 1);
+    if (auto scalar_input = inputs[0].as_ref<ScalarValue>()) {
+        return {ScalarValue::make(
+                imperative::apply(typecvt, scalar_input->value())[0])};
+    } else {
+        return imperative::apply(typecvt, inputs);
+    }
+}
+
+std::vector<ValueRef> collective_comm_rule(
+        const CollectiveComm& collective_comm, Span<ValueRef> inputs) {
+    mgb_assert(inputs.size() == 1);
+    static std::unordered_set<CollectiveComm::Mode> modes = {
+            CollectiveComm::Mode::ALL_REDUCE_MAX, CollectiveComm::Mode::ALL_REDUCE_MIN,
+            CollectiveComm::Mode::ALL_REDUCE_SUM, CollectiveComm::Mode::BROADCAST,
+            CollectiveComm::Mode::REDUCE_SUM,
+    };
+    if (modes.count(collective_comm.mode) == 0) {
+        return imperative::apply(collective_comm, inputs);
+    }
+    if (auto scalar_input = inputs[0].as_ref<ScalarValue>()) {
+        return {ScalarValue::make(
+                imperative::apply(collective_comm, scalar_input->value())[0])};
+    } else {
+        return imperative::apply(collective_comm, inputs);
+    }
+}
+
+std::vector<ValueRef> param_pack_split_rule(
+        const ParamPackSplit& param_pack_split, Span<ValueRef> inputs) {
+    auto outputs = imperative::apply(param_pack_split, unwrap_inputs(inputs));
+    size_t nr_outputs = outputs.size();
+    mgb_assert(nr_outputs == param_pack_split.shapes.size());
+    for (size_t i = 0; i < nr_outputs; ++i) {
+        if (param_pack_split.shapes[i].empty()) {
+            outputs[i] = ScalarValue::make(outputs[i]);
+        }
+    }
+    return outputs;
+}
+
+std::vector<ValueRef> dot_rule(const Dot& dot, Span<ValueRef> inputs) {
+    return {ScalarValue::make(imperative::apply(dot, unwrap_inputs(inputs))[0])};
+}
+
+std::vector<ValueRef> add_axis_rule(const AddAxis& add_axis, Span<ValueRef> inputs) {
+    mgb_assert(inputs.size() == 1);
+    if (auto scalar_input = inputs[0].as_ref<ScalarValue>()) {
+        mgb_assert(add_axis.axis[0] == 0);
+        if (add_axis.axis.size() == 1) {
+            return {scalar_input->value()};
+        } else {
+            std::vector<int32_t> axis(add_axis.axis.begin() + 1, add_axis.axis.end());
+            return imperative::apply(
+                    ApplyOp(*AddAxis::make(axis, add_axis.scope())),
+                    scalar_input->value());
+        }
+    } else {
+        return imperative::apply(add_axis, inputs);
+    }
+}
+
+std::vector<ValueRef> remote_recv_rule(
+        const RemoteRecv& remote_recv, Span<ValueRef> inputs) {
+    if (remote_recv.shape.empty()) {
+        std::vector<int32_t> shape = {1};
+        auto remote_recv_no_scalar = RemoteRecv::make(
+                remote_recv.key, remote_recv.addr, remote_recv.port,
+                remote_recv.rank_from, remote_recv.cn, shape, remote_recv.dtype,
+                remote_recv.backend);
+        remote_recv_no_scalar->set_scope(remote_recv.scope());
+        return imperative::apply(
+                ApplyOp(*remote_recv_no_scalar), unwrap_inputs(inputs));
+    } else {
+        return imperative::apply(remote_recv, unwrap_inputs(inputs));
+    }
+}
+
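+// CheckNonFinite emits one output per input plus a trailing scalar flag: the
+// flag is always wrapped as a ScalarValue, and every other output inherits
+// the scalar-ness of its corresponding input.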
+std::vector<ValueRef> check_no_finite_rule(
+        const CheckNonFinite& check_no_finite, Span<ValueRef> inputs) {
+    auto outputs = imperative::apply(check_no_finite, unwrap_inputs(inputs));
+    mgb_assert(outputs.size() == inputs.size() + 1, "output size mismatch");
+    outputs.back() = ScalarValue::make(outputs.back());
+    for (size_t i = 0; i < inputs.size(); ++i) {
+        if (inputs[i].is<ScalarValue>()) {
+            outputs[i] = ScalarValue::make(outputs[i]);
+        }
+    }
+    return outputs;
+}
+
+std::vector<ValueRef> subtensor_rule(
+        const Subtensor& subtensor, Span<ValueRef> inputs) {
+    mgb_assert(inputs.size() >= 1);
+    auto input = inputs[0];
+    size_t ndim = input.is<ScalarValue>() ? 0 : input.shape()->ndim;
+    for (auto&& [axis, begin, end, step, idx] : subtensor.items) {
+        if (idx) {
+            ndim--;
+        }
+    }
+    auto output = imperative::apply(subtensor, unwrap_inputs(inputs))[0];
+    if (!ndim) {
+        return {ScalarValue::make(output)};
+    } else {
+        return {output};
+    }
+}
+
+std::vector<ValueRef> get_var_shape_rule(
+        const GetVarShape& get_var_shape, Span<ValueRef> inputs) {
+    bool all_scalar = true;
+    mgb_assert(inputs.size() >= 1);
+    for (auto&& input : inputs) {
+        if (!input.is<ScalarValue>()) {
+            all_scalar = false;
+        }
+    }
+    if (all_scalar) {
+        auto device = inputs[0].cast<ScalarValue>().value().device();
+        auto storage = HostStorage::make(*device);
+        // storage->ensure_size(1);
+        return imperative::apply(
+                CreateTensor(
+                        CreateTensor::Const, *device, dtype::Int32(), ValueShape{0}),
+                storage);
+    } else {
+        return imperative::apply(get_var_shape, unwrap_inputs(inputs));
+    }
+}
+
+std::vector<ValueRef> fastpath_copy_rule(
+        const FastpathCopy& fastpath_copy, Span<ValueRef> inputs) {
+    mgb_assert(inputs.size() == 1);
+    bool is_scalar = inputs[0].is<ScalarValue>();
+    auto output = imperative::apply(fastpath_copy, unwrap_inputs(inputs))[0];
+    if (is_scalar) {
+        return {ScalarValue::make(output)};
+    } else {
+        return {output};
+    }
+}
+
+std::vector<ValueRef> reshape_rule(const Reshape& reshape, Span<ValueRef> inputs) {
+    mgb_assert(inputs.size() == 2);
+    bool is_scalar = is_scalar_shape(inputs[1]);
+    auto unwrapped_input = unwrap_input(inputs[0]);
+    if (is_scalar) {
+        return {ScalarValue::make(imperative::apply(
+                reshape, unwrapped_input,
+                make_scalar_shape(*unwrapped_input.device()))[0])};
+    } else {
+        return imperative::apply(reshape, unwrap_inputs(inputs));
+    }
+}
+
+std::vector<ValueRef> broadcast_rule(
+        const Broadcast& broadcast, Span<ValueRef> inputs) {
+    mgb_assert(inputs.size() == 2);
+    bool is_scalar = is_scalar_shape(inputs[1]);
+    auto unwrapped_input = unwrap_input(inputs[0]);
+    if (is_scalar) {
+        return {ScalarValue::make(imperative::apply(
+                broadcast, unwrapped_input,
+                make_scalar_shape(*unwrapped_input.device()))[0])};
+    } else {
+        return imperative::apply(broadcast, unwrap_inputs(inputs));
+    }
+}
+
+std::vector<ValueRef> copy_rule(const Copy& copy, Span<ValueRef> inputs) {
+    mgb_assert(inputs.size() == 1);
+    bool is_scalar = inputs[0].is<ScalarValue>();
+    if (is_scalar) {
+        return {ScalarValue::make(imperative::apply(copy, unwrap_inputs(inputs))[0])};
+    } else {
+        return imperative::apply(copy, unwrap_inputs(inputs));
+    }
+}
+
+std::vector<ValueRef> inplace_add_rule(
+        const InplaceAdd& inplace_add, Span<ValueRef> inputs) {
+    mgb_assert(inputs.size() == 4);
+    bool is_scalar = inputs[0].is<ScalarValue>();
+    if (is_scalar) {
+        return {ScalarValue::make(
+                imperative::apply(inplace_add, unwrap_inputs(inputs))[0])};
+    } else {
+        return imperative::apply(inplace_add, unwrap_inputs(inputs));
+    }
+}
+
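+// Populates scalar_rules exactly once during static initialization; the
+// anonymous instance `_` below exists purely for this constructor's side
+// effect.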
+struct ScalarRuleRegistry {
+    ScalarRuleRegistry() {
+        register_scalar_rule(elemwise_rule);
+        register_scalar_rule(remove_axis_rule);
+        register_scalar_rule(reduce_rule);
+        register_scalar_rule(typecvt_rule);
+        register_scalar_rule(collective_comm_rule);
+        register_scalar_rule(param_pack_split_rule);
+        register_scalar_rule(dot_rule);
+        register_scalar_rule(add_axis_rule);
+        register_scalar_rule(remote_recv_rule);
+        register_scalar_rule(check_no_finite_rule);
+        register_scalar_rule(subtensor_rule);
+        register_scalar_rule(get_var_shape_rule);
+        register_scalar_rule(fastpath_copy_rule);
+        register_scalar_rule(reshape_rule);
+        register_scalar_rule(broadcast_rule);
+        register_scalar_rule(copy_rule);
+        register_scalar_rule(inplace_add_rule);
+    }
+} _;
+} // namespace
+
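+// Dispatch entry point: ApplyOp goes through the scalar_rules table when a
+// rule is registered; CreateTensor/GetAttr/IsScalar/IdentityLike operators
+// get dedicated handling; every other operator simply sees unwrapped inputs.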
+std::vector<ValueRef> ScalarTransformation::apply_transformation(
+        const Operator& op, Span<ValueRef> inputs) {
+    if (auto apply_op = op.as<ApplyOp>()) {
+        auto iter = scalar_rules.find(apply_op->op().dyn_typeinfo());
+        if (iter != scalar_rules.end()) {
+            return iter->second(apply_op->op(), inputs);
+        } else {
+            // TODO: repeat op
+            return imperative::apply(op, unwrap_inputs(inputs));
+        }
+    } else if (auto* create_tensor = op.as<CreateTensor>()) {
+        if (create_tensor->shape().is_scalar()) {
+            ValueShape scalar_shape = {1};
+            CreateTensor scalar_op(
+                    create_tensor->kind(), create_tensor->device(),
+                    create_tensor->dtype(), scalar_shape);
+            return {ScalarValue::make(imperative::apply(scalar_op, inputs)[0])};
+        } else {
+            return imperative::apply(op, inputs);
+        }
+    } else if (auto* get_attr = op.as<GetAttr>()) {
+        bool is_scalar = inputs.as_array<1>()[0].is<ScalarValue>();
+        auto output = imperative::apply(op, unwrap_inputs(inputs))[0];
+        if (!is_scalar) {
+            return {output};
+        }
+        switch (get_attr->attr()) {
+            case GetAttr::Shape: {
+                // scalar shape
+                return {ShapeValue::make()};
+            }
+            case GetAttr::Value: {
+                auto& hv = output.cast<HostValue>();
+                mgb_assert(
+                        hv.shape() == ValueShape({1}),
+                        "underlying value should have shape {1}, got %s",
+                        hv.shape().to_string().c_str());
+                return {HostValue::make(hv.dtype(), ValueShape(), hv.storage())};
+            }
+            case GetAttr::Data: {
+                auto& dv = output.cast<DeviceValue>();
+                mgb_assert(
+                        dv.shape() == ValueShape({1}),
+                        "underlying value should have shape {1}, got %s",
+                        dv.shape().to_string().c_str());
+                return {DeviceValue::make(dv.dtype(), ValueShape(), dv.storage())};
+            }
+            default:
+                return {output};
+        }
+    } else if (op.as<IsScalar>()) {
+        return {BoolValue::make(inputs.as_array<1>()[0].is<ScalarValue>())};
+    } else if (op.is<Operator::IdentityLike>()) {
+        bool is_scalar = inputs.as_array<1>()[0].is<ScalarValue>();
+        if (is_scalar) {
+            return {ScalarValue::make(imperative::apply(op, unwrap_inputs(inputs))[0])};
+        } else {
+            return imperative::apply(op, inputs);
+        }
+    } else {
+        return imperative::apply(op, unwrap_inputs(inputs));
+    }
+};
+
+} // namespace imperative
+} // namespace mgb
diff --git a/imperative/src/include/megbrain/imperative/transformations/scalar.h b/imperative/src/include/megbrain/imperative/transformations/scalar.h
new file mode 100644
index 000000000..a56fa093d
--- /dev/null
+++ b/imperative/src/include/megbrain/imperative/transformations/scalar.h
@@ -0,0 +1,60 @@
+/**
+ * \file imperative/src/include/megbrain/imperative/transformations/scalar.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#pragma once
+
+#include "megbrain/imperative/dispatch.h"
+#include "megbrain/imperative/ops/autogen.h"
+
+namespace mgb::imperative {
+
+class ScalarValue final : public ValueImpl<ScalarValue> {
+private:
+    ValueRef m_value;
+
+public:
+    ScalarValue(ValueRef value) : m_value(value) {}
+
+    std::string to_string() const override {
+        return ssprintf("ScalarValue{value=%s}", m_value.to_string().c_str());
+    }
+
+    ValueRef value() const { return m_value; }
+
+    void clear() override { m_value = {}; }
+
+    void on_watch() override { m_value.watch(); }
+
+    void on_unwatch() override { m_value.unwatch(); }
+};
+
+/**
+ * \brief simulates scalars because the megbrain graph system doesn't support them
+ *
+ * Assume we have 'a = ScalarValue(b)'; then 'a.shape == []' while 'b.shape == [1]'.
+ * This transformation simulates scalars with a wrapper flag: a value is a scalar
+ * if and only if it is a ScalarValue, so no scalar ever reaches the layers below.
+ */
+class ScalarTransformation final : public Transformation {
+public:
+    std::vector<ValueRef> apply_transformation(
+            const Operator& op, Span<ValueRef> inputs) override;
+
+    ValueRef unwrap(ValueRef value) override {
+        mgb_assert(!value.is<ScalarValue>());
+        return value;
+    }
+
+    std::string name() const override { return "ScalarTransformation"; }
+};
+
+} // namespace mgb::imperative
-- 
GitLab
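
A minimal sketch (not part of the patch) of how one more rule could be hooked
into the table above, following the same pattern as the rules in scalar.cpp.
`MyUnaryOp` is a hypothetical OpDef used only for illustration; the helpers it
relies on (unwrap_inputs, ScalarValue, register_scalar_rule) are the ones this
patch introduces.

    // Hypothetical op: any single-input op whose output keeps the input's
    // scalar-ness would follow the same shape.
    std::vector<ValueRef> my_unary_rule(const MyUnaryOp& op, Span<ValueRef> inputs) {
        mgb_assert(inputs.size() == 1);
        bool is_scalar = inputs[0].is<ScalarValue>();
        // Strip the ScalarValue wrapper, apply the op on the shape-{1}
        // tensor underneath, then restore the wrapper on the way out.
        auto output = imperative::apply(op, unwrap_inputs(inputs))[0];
        if (is_scalar) {
            return {ScalarValue::make(output)};
        } else {
            return {output};
        }
    }

    // Registered once alongside the other rules, e.g. from
    // ScalarRuleRegistry's constructor:
    //     register_scalar_rule(my_unary_rule);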