Commit 5798f6ce, authored by: Megvii Engine Team

feat(subgraph): add OpMeth make_forward_graph

GitOrigin-RevId: 171301fc2be5f867d4d653bc9a3fb22a94c289e6
Parent 48db45d1
......@@ -100,6 +100,19 @@ std::vector<std::pair<const char*, std::string>> OpDef::props(
return def.trait()->props(def);
}
EncodedSubraph OpDef::make_forward_graph(
        const OpDef& def,
        const SmallVector<LogicalTensorDesc>& inputs) {
    using ForwardGraphCache =
            OpMethResultCache<EncodedSubraph, SmallVector<bool>, SmallVector<bool>>;
    thread_local ForwardGraphCache cache;
    decltype(cache)::key_t cache_key{
            const_cast<OpDef&>(def).shared_from_this(), inputs};
    auto iter = cache.find(cache_key);
    if (iter == cache.end()) {
        iter = cache.insert({cache_key, def.trait()->make_forward_graph(def, inputs)})
                       .first;
    }
    return iter->second;
}
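The new OpDef::make_forward_graph dispatches through the trait and memoizes the result in a thread_local cache keyed by the op instance and its input descriptors, the same pattern already used for make_backward_graph. A minimal, self-contained sketch of that memoization pattern, with std::map standing in for OpMethResultCache and a string key standing in for the (op, inputs) tuple (all names below are illustrative, not MegEngine API):

    #include <cstdio>
    #include <map>
    #include <string>

    // Hypothetical stand-ins: build_graph() is the expensive call being cached,
    // the string key plays the role of OpMethResultCache's key_t.
    static std::string build_graph(const std::string& key) {
        std::printf("building graph for %s\n", key.c_str());  // runs once per key per thread
        return "graph(" + key + ")";
    }

    static const std::string& cached_forward_graph(const std::string& key) {
        thread_local std::map<std::string, std::string> cache;
        auto iter = cache.find(key);
        if (iter == cache.end()) {
            iter = cache.insert({key, build_graph(key)}).first;
        }
        return iter->second;
    }

    int main() {
        cached_forward_graph("Elemwise{ADD}");  // cache miss: builds the graph
        cached_forward_graph("Elemwise{ADD}");  // cache hit: returns the memoized value
    }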
std::string OpDef::to_string() const {
std::string builder = trait()->make_name(*this) + "{";
for (auto&& [name, value]: props(*this)) {
......
......@@ -16,6 +16,7 @@
#include "megbrain/imperative/op_def.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/proxy_graph_detail.h"
#include "megbrain/imperative/subgraph_detail.h"
#include "megbrain/tensor.h"
#include "./op_trait.h"
......@@ -38,24 +39,45 @@ StaticData& static_data() {
return data;
}
void OpMethFallback::impl(ApplyOnPhysicalTensor& func,
void OpMethFallbackByProxyGraph::impl(ApplyOnPhysicalTensor& func,
op_meth_tag::ApplyOnPhysicalTensor) {
func.Base::operator=(proxy_graph_detail::apply_on_physical_tensor);
}
void OpMethFallback::impl(Execute& func, op_meth_tag::Execute) {
void OpMethFallbackByProxyGraph::impl(Execute& func, op_meth_tag::Execute) {
func.Base::operator=(proxy_graph_detail::execute);
}
void OpMethFallback::impl(InferOutputMemDesc& func,
void OpMethFallbackByProxyGraph::impl(InferOutputMemDesc& func,
op_meth_tag::InferOutputMemDesc) {
func.Base::operator=(proxy_graph_detail::infer_output_mem_desc);
}
void OpMethFallback::impl(InferOutputAttrsFallible& func,
void OpMethFallbackByProxyGraph::impl(InferOutputAttrsFallible& func,
op_meth_tag::InferOutputAttrsFallible) {
func.Base::operator=(proxy_graph_detail::infer_output_attrs_fallible);
}
void OpMethFallback::impl(GradMaker& func, op_meth_tag::GradMaker) {
void OpMethFallbackByProxyGraph::impl(GradMaker& func, op_meth_tag::GradMaker) {
func.Base::operator=(proxy_graph_detail::make_backward_graph);
}
void OpMethFallbackFromSubgraph::impl(ApplyOnPhysicalTensor& func,
op_meth_tag::ApplyOnPhysicalTensor) {
func.Base::operator=(subgraph_detail::apply_on_physical_tensor);
}
void OpMethFallbackFromSubgraph::impl(InferOutputMemDesc& func,
op_meth_tag::InferOutputMemDesc) {
func.Base::operator=(subgraph_detail::infer_output_mem_desc);
}
void OpMethFallbackFromSubgraph::impl(ApplyOnVarNode& func,
op_meth_tag::ApplyOnVarNode) {
func.Base::operator=(subgraph_detail::apply_on_var_node);
}
void OpMethFallbackFromSubgraph::impl(InferOutputAttrsFallible& func,
op_meth_tag::InferOutputAttrsFallible) {
func.Base::operator=(subgraph_detail::infer_output_attrs_fallible);
}
void OpMethFallbackFromSubgraph::impl(GradMaker& func, op_meth_tag::GradMaker) {
func.Base::operator=(subgraph_detail::make_backward_graph);
}
void OpMethFallback::impl(DecideDispatchMode& func,
op_meth_tag::DecideDispatchMode) {
static auto decide_dispatch_mode =
......@@ -99,16 +121,20 @@ void OpTrait::for_each_trait(thin_function<void(OpTrait&)> visitor){
}
OpTraitRegistry& OpTraitRegistry::fallback() {
using Mode = detail::OpMethFallbackMode;
uint64_t mode = Mode::None;
if (trait->make_forward_graph) {
mode |= Mode::FromSubgraph;
}
if (trait->apply_on_var_node) {
// fallback to proxy graph impl
trait->apply_on_physical_tensor.allow_fallback = true;
trait->execute.allow_fallback = true;
trait->infer_output_mem_desc.allow_fallback = true;
trait->infer_output_attrs_fallible.allow_fallback = true;
trait->make_backward_graph.allow_fallback = true;
mode |= Mode::ByProxyGraph;
}
trait->decide_dispatch_mode.allow_fallback = true;
trait->make_name.allow_fallback = true;
mode |= Mode::Default;
#define SET_FALLBACK_MODE(meth) \
trait->meth.fallback_mode = mode;
FOR_EACH_OP_METH(SET_FALLBACK_MODE)
#undef SET_FALLBACK_MODE
return *this;
}
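OpTraitRegistry::fallback() now composes a per-trait bitmask instead of flipping individual allow_fallback flags: Mode::FromSubgraph is enabled when the trait registers make_forward_graph, Mode::ByProxyGraph when it registers apply_on_var_node, and Mode::Default always. The mask arithmetic is small enough to check standalone; a minimal sketch (plain C++, mirroring the constants rather than including MegEngine headers):

    #include <cstdint>

    // Local mirror of detail::OpMethFallbackMode, just to exercise the mask math.
    constexpr uint64_t None = 0, Default = 1, ByProxyGraph = 2, FromSubgraph = 4;

    // A trait with both apply_on_var_node and make_forward_graph enables all
    // three tiers; a trait with neither keeps only the default tier.
    static_assert((None | Default | ByProxyGraph | FromSubgraph) == 7,
                  "full fallback mask");
    static_assert((None | Default) == 1, "minimal fallback mask");

    // FOR_EACH_OP_METH(SET_FALLBACK_MODE) then stamps the chosen mask onto every
    // OpMeth of the trait; a single expansion looks like (illustrative):
    //     trait->apply_on_physical_tensor.fallback_mode = mode;
    int main() {}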
......
......@@ -95,9 +95,18 @@ OpMethType(IsSame,
OpMethType(MakeNameFunc,
std::string(const OpDef&));
OpMethType(GraphMaker,
decltype(OpDef::make_forward_graph));
// clang-format on
namespace detail {
struct OpMethImplBase {
template <typename Tag, typename RType, typename... Args>
static void impl(thin_function<RType(Args...)>& func, Tag) {}
};
struct OpMethNotImpl {
template <typename Tag, typename RType, typename... Args>
static void impl(thin_function<RType(Args...)>& func, Tag) {
......@@ -106,8 +115,15 @@ struct OpMethNotImpl {
};
}
};
struct OpMethFallback : public OpMethNotImpl {
using OpMethNotImpl::impl;
struct OpMethFallback: OpMethImplBase {
using OpMethImplBase::impl;
static void impl(DecideDispatchMode& func, op_meth_tag::DecideDispatchMode);
static void impl(MakeNameFunc& func, op_meth_tag::MakeNameFunc);
};
struct OpMethFallbackByProxyGraph: OpMethImplBase {
using OpMethImplBase::impl;
static void impl(ApplyOnPhysicalTensor& func,
op_meth_tag::ApplyOnPhysicalTensor);
static void impl(Execute& func, op_meth_tag::Execute);
......@@ -115,18 +131,48 @@ struct OpMethFallback : public OpMethNotImpl {
static void impl(InferOutputAttrsFallible& func,
op_meth_tag::InferOutputAttrsFallible);
static void impl(GradMaker& func, op_meth_tag::GradMaker);
static void impl(DecideDispatchMode& func, op_meth_tag::DecideDispatchMode);
static void impl(MakeNameFunc& func, op_meth_tag::MakeNameFunc);
};
struct OpMethFallbackFromSubgraph: OpMethImplBase {
using OpMethImplBase::impl;
static void impl(ApplyOnPhysicalTensor& func,
op_meth_tag::ApplyOnPhysicalTensor);
static void impl(InferOutputMemDesc& func, op_meth_tag::InferOutputMemDesc);
static void impl(ApplyOnVarNode& func, op_meth_tag::ApplyOnVarNode);
static void impl(InferOutputAttrsFallible& func,
op_meth_tag::InferOutputAttrsFallible);
static void impl(GradMaker& func, op_meth_tag::GradMaker);
};
struct OpMethFallbackMode {
static constexpr uint64_t None = 0;
static constexpr uint64_t Default = 1;
static constexpr uint64_t ByProxyGraph = 2;
static constexpr uint64_t FromSubgraph = 4;
};
template <typename Tag, typename RType, typename... Args>
struct OpMeth<Tag, RType(Args...)> : public thin_function<RType(Args...)> {
using Base = thin_function<RType(Args...)>;
OpMeth() : Base{}, allow_fallback(false){};
OpMeth() : Base{}{};
explicit OpMeth(const Base& base) { this->Base::operator=(base); }
using Base::operator bool;
RType operator()(Args... args) const {
if (!this->Base::operator bool()) {
if (allow_fallback) {
uint64_t mode_mask = ~uint64_t(0);
auto match_mode = [&](uint64_t mode){
if ((fallback_mode & mode_mask) & mode) {
mode_mask &= ~mode;
return true;
}
return false;
};
while (!this->Base::operator bool()) {
using Mode = OpMethFallbackMode;
if (match_mode(Mode::FromSubgraph)) {
OpMethFallbackFromSubgraph::impl(*const_cast<OpMeth*>(this), Tag{});
} else if (match_mode(Mode::ByProxyGraph)) {
OpMethFallbackByProxyGraph::impl(*const_cast<OpMeth*>(this), Tag{});
} else if (match_mode(Mode::Default)) {
OpMethFallback::impl(*const_cast<OpMeth*>(this), Tag{});
} else {
OpMethNotImpl::impl(*const_cast<OpMeth*>(this), Tag{});
......@@ -134,7 +180,7 @@ struct OpMeth<Tag, RType(Args...)> : public thin_function<RType(Args...)> {
}
return this->Base::operator()(std::forward<Args>(args)...);
}
bool allow_fallback = false;
uint64_t fallback_mode = OpMethFallbackMode::None;
};
} // namespace detail
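With fallback_mode in place, OpMeth::operator() resolves a missing implementation lazily: if the underlying thin_function is still unbound it walks the enabled tiers from most specific to least specific (FromSubgraph, then ByProxyGraph, then Default, then OpMethNotImpl), and mode_mask clears each bit after it is tried so no tier is attempted twice. A self-contained sketch of that bookkeeping (plain C++, no MegEngine types):

    #include <cstdint>
    #include <cstdio>

    int main() {
        constexpr uint64_t Default = 1, ByProxyGraph = 2, FromSubgraph = 4;
        // e.g. a trait that registered both apply_on_var_node and make_forward_graph
        uint64_t fallback_mode = Default | ByProxyGraph | FromSubgraph;
        uint64_t mode_mask = ~uint64_t(0);
        auto match_mode = [&](uint64_t mode) {
            if ((fallback_mode & mode_mask) & mode) {
                mode_mask &= ~mode;  // never retry the same tier
                return true;
            }
            return false;
        };
        // Same priority order as the while loop in OpMeth::operator():
        if (match_mode(FromSubgraph)) std::printf("try subgraph_detail impl\n");
        if (match_mode(ByProxyGraph)) std::printf("try proxy_graph_detail impl\n");
        if (match_mode(Default)) std::printf("try default impl\n");
        if (!match_mode(Default)) std::printf("Default bit already consumed\n");
    }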
......@@ -153,6 +199,7 @@ struct OpTrait {
HashFunc hash;
IsSame is_same_st;
MakeNameFunc make_name;
GraphMaker make_forward_graph;
OpTrait(const char* name);
static OpTrait* find_by_name(const char* name);
static OpTrait* find_by_typeinfo(Typeinfo* type);
......@@ -173,7 +220,9 @@ struct OpTrait {
cb(props) \
cb(hash) \
cb(is_same_st) \
cb(make_name)
cb(make_name) \
cb(make_forward_graph) \
// clang-format on
struct OpTraitRegistry {
......
/**
* \file imperative/src/impl/subgraph_detail.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/imperative/subgraph_detail.h"
#include "megbrain/imperative/graph_builder.h"
#include "megbrain/opr/io.h"
#include "megbrain/imperative/ops/autogen.h"
#include "./op_trait.h"
namespace mgb {
namespace imperative {
namespace subgraph_detail {
VarNodeArray apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    SmallVector<LogicalTensorDesc> input_descs;
    for (auto&& input: inputs) {
        input_descs.push_back({TensorLayout{input->dtype()}, input->comp_node()});
    }
    auto apply_functor = [](const std::shared_ptr<OpDef>& op,
                            const VarNodeArray& inputs, size_t nr_outputs) {
        return OpDef::apply_on_var_node(*op, inputs);
    };
    auto const_functor = [&](const TensorPtr& value) {
        return opr::ImmutableTensor::make(
                       *inputs[0]->owner_graph(), value->get_value())
                .node();
    };
    auto subgraph = def.trait()->make_forward_graph(def, input_descs);
    auto outputs = subgraph.apply(inputs, apply_functor, const_functor);
    return outputs;
}
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def,
        const SmallVector<LogicalTensorDesc>& inputs) {
    auto subgraph = def.trait()->make_forward_graph(def, inputs);
    bool all_validated = true;
    auto apply_functor = [&](const std::shared_ptr<OpDef>& op,
                             const SmallVector<LogicalTensorDesc>& inputs,
                             size_t nr_outputs) {
        auto [outputs, validated] = OpDef::infer_output_attrs_fallible(*op, inputs);
        all_validated = all_validated && validated;
        return outputs;
    };
    auto const_functor = [&](const TensorPtr& value) {
        return LogicalTensorDesc{value->layout(), value->comp_node(),
                                 value->get_value().proxy_to_default_cpu()};
    };
    auto outputs = subgraph.apply(inputs, apply_functor, const_functor);
    return {outputs, all_validated};
}
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def,
        SmallVector<TensorPtr> inputs) {
    SmallVector<LogicalTensorDesc> input_descs;
    for (auto&& input: inputs) {
        input_descs.push_back({input->layout(), input->comp_node()});
    }
    auto subgraph = def.trait()->make_forward_graph(def, input_descs);
    auto apply_functor = [](const std::shared_ptr<OpDef>& op,
                            const SmallVector<TensorPtr>& inputs, size_t nr_outputs) {
        return OpDef::apply_on_physical_tensor(*op, inputs);
    };
    auto const_functor = [&](const TensorPtr& value) {
        return value;
    };
    auto outputs = subgraph.apply(inputs, apply_functor, const_functor);
    return outputs;
}
static EncodedSubraph make_backward_graph_from_forward(
        const SmallVector<LogicalTensorDesc>& inputs,
        const SmallVector<bool>& input_requires_grad,
        const SmallVector<bool>& output_has_grad,
        EncodedSubraph forward_graph) {
    using namespace std::placeholders;
    using var_t = Subgraph::var_t;
    using vars_t = Subgraph::vars_t;
    Subgraph::Builder<LogicalTensorDesc> builder(
            [](auto&& op, auto&& input_descs, size_t nr_outputs) {
                auto [descs, _] = OpDef::infer_output_attrs_fallible(*op, input_descs);
                return descs;
            });
    auto accum_grad = [&](var_t lhs, var_t rhs) {
        return builder.write_expr(
                Elemwise::make(Elemwise::Mode::ADD), {lhs, rhs}, 1)[0];
    };
    GradContext<var_t> grad_context{accum_grad};
    auto input_vars = builder.write_inputs(inputs);
    auto outputs = forward_graph.apply(
            input_vars,
            std::bind(&decltype(builder)::write_expr, &builder, _1, _2, _3),
            [&](TensorPtr constant) {
                return builder.write_constant(
                        constant, {constant->layout(), constant->comp_node()});
            });
    size_t nr_outputs = outputs.size();
    auto apply_mask = [](auto&& values, SmallVector<bool> mask) {
        mgb_assert(mask.size() == values.size(), "");
        std::decay_t<decltype(values)> results;
        for (size_t i = 0; i < mask.size(); ++i) {
            if (mask[i]) {
                results.push_back(values[i]);
            }
        }
        return results;
    };
    grad_context.mark_require_grads(apply_mask(input_vars, input_requires_grad));
    builder.iterate([&](std::list<Subgraph::expr_t>::iterator iter) {
        grad_context.record_expr(iter->op, iter->inputs, iter->outputs);
    });
    auto output_descs = builder.get_descs(outputs);
    auto computed_outputs = builder.write_inputs(output_descs);
    auto output_grads = builder.write_inputs(output_descs);
    grad_context.backward(
            apply_mask(outputs, output_has_grad),
            apply_mask(output_grads, output_has_grad),
            [&](Subgraph::expr_t expr, vars_t output_grads) {
                auto bg = OpDef::make_backward_graph(
                        *expr.op, builder.get_descs(expr.inputs),
                        grad_context.get_require_grads(expr.inputs),
                        grad_context.get_has_grads(expr.outputs));
                if (bg.graph.empty()) {
                    return vars_t(expr.inputs.size(), 0);
                }
                vars_t grad_inputs;
                grad_inputs.insert(grad_inputs.end(), expr.inputs.begin(),
                                   expr.inputs.end());
                grad_inputs.insert(grad_inputs.end(), expr.outputs.begin(),
                                   expr.outputs.end());
                grad_inputs.insert(grad_inputs.end(), output_grads.begin(),
                                   output_grads.end());
                auto apply_functor = std::bind(
                        &decltype(builder)::write_expr, &builder, _1, _2, _3);
                auto const_functor = [&](TensorPtr constant) {
                    return builder.write_constant(
                            constant, {constant->layout(), constant->comp_node()});
                };
                return bg.apply(grad_inputs, apply_functor, const_functor);
            });
    builder.add_outputs(grad_context.get_grads(input_vars));
    for (size_t i = 0; i < nr_outputs; ++i) {
        builder.replace_var(outputs[i], computed_outputs[i]);
    }
    auto backward_graph = builder.encode();
    return backward_graph;
}
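make_backward_graph_from_forward re-traces the trait's forward graph through a Subgraph::Builder, records every expression into a GradContext, and then runs the backward pass, accumulating gradients with Elemwise ADD. The three write_inputs calls fix the input order of the encoded result: forward inputs first, then the already-computed forward outputs, then the output gradients; its outputs are the gradients of the forward inputs. A hedged sketch of how a caller would line up those arguments (fwd_inputs, fwd_outputs, output_grads and the two functors are assumed to be in scope; none of this is code from the diff):

    // `bg` is an EncodedSubraph produced by make_backward_graph_from_forward above.
    // Its inputs follow the order established by the three write_inputs calls:
    // forward inputs, forward outputs, output grads.
    SmallVector<TensorPtr> backward_args;
    backward_args.insert(backward_args.end(), fwd_inputs.begin(), fwd_inputs.end());
    backward_args.insert(backward_args.end(), fwd_outputs.begin(), fwd_outputs.end());
    backward_args.insert(backward_args.end(), output_grads.begin(), output_grads.end());
    // Replayed with the same apply/const functor contract as the forward
    // interpreters; the result holds one gradient per forward input.
    auto input_grads = bg.apply(backward_args, apply_functor, const_functor);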
EncodedSubraph make_backward_graph(
        const OpDef& def,
        const SmallVector<LogicalTensorDesc>& inputs,
        const SmallVector<bool>& input_requires_grad,
        const SmallVector<bool>& output_has_grad) {
    auto forward_graph = OpDef::make_forward_graph(def, inputs);
    return make_backward_graph_from_forward(
            inputs, input_requires_grad, output_has_grad, forward_graph);
}

std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
        const OpDef& def,
        const SmallVector<TensorPtr>& inputs_tensors,
        const SmallVector<MemoryDesc>& inputs_mems) {
    return {{}, {}};
}

}  // namespace subgraph_detail
}  // namespace imperative
}  // namespace mgb
......@@ -13,6 +13,7 @@
#include "megbrain/graph.h"
#include "megbrain/imperative/physical_tensor.h"
#include "megbrain/imperative/subgraph.h"
#include "megbrain/imperative/utils/to_string.h"
#include "megbrain/imperative/subgraph.h"
......@@ -94,6 +95,10 @@ public:
static std::vector<std::pair<const char*, std::string>> props(
const OpDef& def);
static EncodedSubraph make_forward_graph(
const OpDef& def,
const SmallVector<LogicalTensorDesc>& inputs);
const OpTrait* trait() const;
std::string to_string() const;
......
/**
* \file imperative/src/include/megbrain/imperative/subgraph_detail.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megbrain/imperative/op_def.h"
namespace mgb {
namespace imperative {
namespace subgraph_detail {
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def,
        SmallVector<TensorPtr> inputs);

std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def,
        const SmallVector<LogicalTensorDesc>& inputs);

EncodedSubraph make_backward_graph(
        const OpDef& def,
        const SmallVector<LogicalTensorDesc>& inputs,
        const SmallVector<bool>& input_requires_grad,
        const SmallVector<bool>& output_has_grad);

cg::VarNodeArray apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs);

std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
        const OpDef& def,
        const SmallVector<TensorPtr>& inputs_tensors,
        const SmallVector<MemoryDesc>& inputs_mems);
}  // namespace subgraph_detail
}  // namespace imperative
}  // namespace mgb
\ No newline at end of file
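A hedged sketch of how an op would opt into these generic implementations: register only make_forward_graph on its trait and let OpTraitRegistry::fallback() turn on Mode::FromSubgraph for the rest. MyFusedOp and make_fused_forward_graph are hypothetical names; the OP_TRAIT_REG / .fallback() idiom follows the existing registrations under imperative/src/impl/ops:

    // Hypothetical composite op registration (not part of this diff): only the
    // forward graph is provided, so apply_on_var_node, infer_output_attrs_fallible,
    // apply_on_physical_tensor and make_backward_graph fall back to the
    // subgraph_detail implementations declared above.
    namespace {

    EncodedSubraph make_fused_forward_graph(
            const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs);

    OP_TRAIT_REG(MyFusedOp, MyFusedOp)
            .make_forward_graph(make_fused_forward_graph)
            .fallback();

    }  // anonymous namespace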