提交 8182af6e 编写于 作者: M Megvii Engine Team

fix(mgb): fix strategy of grad_op and opr_attr

GitOrigin-RevId: bb7ab8fa9dd838cf90841ff8a0222661c9eeff04
上级 70209667
......@@ -70,15 +70,6 @@ def _matmul(
maxdim = dim1 if dim1 > dim2 else dim2
compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
Strategy = builtin.ops.MatrixMul.Strategy
strategy = Strategy(0)
if _config._benchmark_kernel:
strategy |= Strategy.PROFILE
else:
strategy |= Strategy.HEURISTIC
if _config._deterministic_kernel:
strategy |= Strategy.REPRODUCIBLE
if dim1 == 1 and dim2 == 1: # dispatch to Dot
(result,) = apply(builtin.Dot(), inp1, inp2)
return result
......
......@@ -621,6 +621,7 @@ def max_pool2d(
pad_h=padding_h,
pad_w=padding_w,
mode="max",
strategy=get_execution_strategy(),
format=conv_format,
)
(output,) = apply(op, inp)
......@@ -665,6 +666,7 @@ def avg_pool2d(
pad_h=padding_h,
pad_w=padding_w,
mode=mode,
strategy=get_execution_strategy(),
format=conv_format,
)
(output,) = apply(op, inp)
......
......@@ -1493,7 +1493,7 @@ py::object _transpose_cpp(py::handle inp_hdl, py::handle args) {
py::object _matmul_cpp(
py::handle inp1, py::handle inp2, py::handle dim1, py::handle dim2,
py::handle transpose_a, py::handle transpose_b, py::handle compute_mode,
py::handle profile, py::handle determistic) {
py::handle profile, py::handle deterministic) {
::megdnn::param::MatrixMul::ComputeMode mode =
::megdnn::param::MatrixMul::ComputeMode::DEFAULT;
if (compute_mode.cast<std::string>().compare(std::string("float32")) == 0) {
......@@ -1506,7 +1506,7 @@ py::object _matmul_cpp(
} else {
cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::HEURISTIC;
}
if (determistic.cast<bool>()) {
if (deterministic.cast<bool>()) {
cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::REPRODUCIBLE;
}
std::shared_ptr<OpDef> op = MatrixMul::make(
......@@ -1523,7 +1523,7 @@ py::object _matmul_cpp(
py::object _batched_matmul_cpp(
py::handle inp1, py::handle inp2, py::handle dim1, py::handle dim2,
py::handle transpose_a, py::handle transpose_b, py::handle compute_mode,
py::handle profile, py::handle determistic) {
py::handle profile, py::handle deterministic) {
::megdnn::param::MatrixMul::ComputeMode mode =
::megdnn::param::MatrixMul::ComputeMode::DEFAULT;
if (compute_mode.cast<std::string>().compare(std::string("float32")) == 0) {
......@@ -1536,7 +1536,7 @@ py::object _batched_matmul_cpp(
} else {
cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::HEURISTIC;
}
if (determistic.cast<bool>()) {
if (deterministic.cast<bool>()) {
cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::REPRODUCIBLE;
}
std::shared_ptr<OpDef> op = BatchedMatrixMul::make(
......
......@@ -10,6 +10,10 @@
*/
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/opr/blas.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/rdnn/profiler.h"
#include "megbrain/serialization/opr_load_dump.h"
#include "../op_trait.h"
......@@ -65,6 +69,42 @@ public:
const serialization::GraphDumpConfig& config() const { mgb_assert(0); }
};
#define cb(FASTRUN_OPR) \
megdnn::param::ExecutionPolicy get_strategy_##FASTRUN_OPR( \
cg::OperatorNodeBase* opr) { \
auto policy = \
opr->cast_final<opr::FASTRUN_OPR>().execution_policy_transient(); \
return policy; \
} \
void set_strategy_##FASTRUN_OPR( \
cg::OperatorNodeBase* opr, megdnn::param::ExecutionPolicy policy) { \
auto&& p = opr->cast_final<opr::FASTRUN_OPR>(); \
p.set_execution_policy(policy); \
}
DNN_FOREACH_FASTRUN_OPR(cb)
#undef cb
typedef thin_function<megdnn::param::ExecutionPolicy(cg::OperatorNodeBase*)> get_func;
typedef thin_function<void(cg::OperatorNodeBase*, megdnn::param::ExecutionPolicy)>
set_func;
static const mgb::thin_hash_table::ThinHashMap<
mgb::Typeinfo*, std::pair<get_func, set_func>>&
get_type2policy() {
static mgb::thin_hash_table::ThinHashMap<
mgb::Typeinfo*, std::pair<get_func, set_func>>
sl_type2policy;
static std::once_flag flag;
std::call_once(flag, [&]() {
#define cb(FASTRUN_OPR) \
sl_type2policy[opr::FASTRUN_OPR::typeinfo()] = \
std::make_pair(get_strategy_##FASTRUN_OPR, set_strategy_##FASTRUN_OPR);
DNN_FOREACH_FASTRUN_OPR(cb)
});
return std::as_const(sl_type2policy);
}
VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& attr = def.cast_final_safe<OprAttr>();
auto config = attr.config;
......@@ -73,7 +113,12 @@ VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto registry = serialization::OprRegistry::find_by_name(attr.type);
mgb_assert(registry, "operator %s not found", attr.type.c_str());
OprParamsLoadContext ctx{attr.param, inputs[0]->owner_graph()};
return registry->loader(ctx, inputs, config).usable_output();
auto opr_with_accessor = registry->loader(ctx, inputs, config);
auto&& opr = opr_with_accessor.opr();
if (get_type2policy().find(opr->dyn_typeinfo()) != get_type2policy().end()) {
get_type2policy().at(opr->dyn_typeinfo()).second(opr, attr.policy);
}
return opr_with_accessor.usable_output();
}
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* opr) {
......@@ -84,7 +129,11 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* opr) {
registry->dumper, "operator %s cannot be serialized",
opr->dyn_typeinfo()->name);
registry->dumper(ctx, *opr);
return OprAttr::make(registry->name, std::move(ctx.m_param), opr->config());
megdnn::param::ExecutionPolicy policy;
if (get_type2policy().find(opr->dyn_typeinfo()) != get_type2policy().end()) {
policy = get_type2policy().at(opr->dyn_typeinfo()).first(opr);
}
return OprAttr::make(registry->name, std::move(ctx.m_param), policy, opr->config());
}
std::vector<std::pair<const char*, std::string>> props(const OpDef& def) {
......@@ -108,6 +157,8 @@ OP_TRAIT_REG(OprAttr, OprAttr)
bool OprAttr::is_same_st(const Hashable& rhs_) const {
auto&& rhs = static_cast<const OprAttr&>(rhs_);
return type == rhs.type && param == rhs.param &&
policy.strategy == rhs.policy.strategy &&
policy.workspace_limit == rhs.policy.workspace_limit &&
config.comp_node() == rhs.config.comp_node() &&
config.output_dtype() == rhs.config.output_dtype();
}
......@@ -115,7 +166,12 @@ bool OprAttr::is_same_st(const Hashable& rhs_) const {
size_t OprAttr::hash() const {
return hash_pair_combine(
hash_pair_combine(
mgb::hash(type), mgb::hash(static_cast<std::vector<char>>(param))),
hash_pair_combine(
mgb::hash(type),
mgb::hash(static_cast<std::vector<char>>(param))),
hash_pair_combine(
static_cast<size_t>(policy.strategy),
policy.workspace_limit)),
config.hash());
}
......
......@@ -12,6 +12,7 @@
#pragma once
#include "megbrain/imperative/op_def.h"
#include "megbrain/opr/param_defs.h"
namespace mgb {
namespace imperative {
......@@ -38,12 +39,16 @@ public:
Type type;
Param param;
megdnn::param::ExecutionPolicy policy;
cg::OperatorNodeConfig config;
OprAttr() = default;
OprAttr(const Type& t) : type(t) {}
OprAttr(const Type& t, const Param& p, const cg::OperatorNodeConfig& c)
: type(t), param(p), config(c) {}
OprAttr(const Type& t, const Param& p, const megdnn::param::ExecutionPolicy ps,
const cg::OperatorNodeConfig& c)
: type(t), param(p), policy(ps), config(c) {}
std::string repr() const;
......
......@@ -157,6 +157,51 @@ TEST(TestImperative, BackwardGraphBasic) {
}
}
TEST(TestImperative, ProfileBackward) {
auto cn = CompNode::load("xpux");
using Policy = megdnn::param::ExecutionPolicy;
using S = Policy::Strategy;
Policy policy;
policy.strategy = S::PROFILE;
{
megdnn::param::Convolution param;
auto op = std::shared_ptr<OpDef>(Convolution::make(param, policy));
LogicalTensorDesc inp_desc = {
TensorLayout({16, 3, 16, 16}, dtype::Float32()), cn};
LogicalTensorDesc weight_desc = {
TensorLayout({16, 3, 5, 5}, dtype::Float32()), cn};
auto bg = OpDef::make_backward_graph(
*op, {inp_desc, weight_desc}, {true, false}, {true});
auto&& bop = (bg.graph.exprs.at(0)).op;
auto&& attr = bop->cast_final_safe<OprAttr>();
// attr.type = ConvolutionBackwardDataV2
mgb_assert(attr.policy.strategy == S::PROFILE);
}
{
megdnn::param::Pooling param;
auto op = std::shared_ptr<OpDef>(Pooling::make(param, policy));
LogicalTensorDesc inp_desc = {
TensorLayout({16, 3, 16, 16}, dtype::Float32()), cn};
auto bg = OpDef::make_backward_graph(*op, {inp_desc}, {true}, {true});
auto&& bop = (bg.graph.exprs.at(0)).op;
auto&& attr = bop->cast_final_safe<OprAttr>();
// attr.type = PoolingBackwardV1
mgb_assert(attr.policy.strategy == S::PROFILE);
}
{
megdnn::param::MatrixMul param;
auto op = std::shared_ptr<OpDef>(MatrixMul::make(param, policy, 2, 2));
LogicalTensorDesc inp1_desc = {TensorLayout({12, 16}, dtype::Float32()), cn};
LogicalTensorDesc inp2_desc = {TensorLayout({16, 20}, dtype::Float32()), cn};
auto bg = OpDef::make_backward_graph(
*op, {inp1_desc, inp2_desc}, {true, false}, {true});
auto&& bop = (bg.graph.exprs.at(0)).op;
auto&& attr = bop->cast_final_safe<OprAttr>();
// attr.type = MatrixMulV2
mgb_assert(attr.policy.strategy == S::PROFILE);
}
}
TEST(TestImperative, BackwardGraphIdentity) {
HostTensorGenerator<> gen;
auto host_a = gen({42}), host_dc = gen({42});
......
......@@ -185,17 +185,21 @@ MGB_IMPL_OPR_GRAD(MatrixMul) {
if (wrt_idx == 0) {
// A * B = C, A' = C' * Bt
if (opr.param().transposeA) {
grad = MatrixMul::make(i1, og, {opr.param().transposeB, true});
grad = MatrixMul::make(
i1, og, {opr.param().transposeB, true}, opr.execution_policy());
} else {
grad = MatrixMul::make(og, i1, {false, !opr.param().transposeB});
grad = MatrixMul::make(
og, i1, {false, !opr.param().transposeB}, opr.execution_policy());
}
} else {
mgb_assert(wrt_idx == 1);
// A * B = C, B' = At * C'
if (opr.param().transposeB) {
grad = MatrixMul::make(og, i0, {true, opr.param().transposeA});
grad = MatrixMul::make(
og, i0, {true, opr.param().transposeA}, opr.execution_policy());
} else {
grad = MatrixMul::make(i0, og, {!opr.param().transposeA, false});
grad = MatrixMul::make(
i0, og, {!opr.param().transposeA, false}, opr.execution_policy());
}
}
return grad.node();
......@@ -358,17 +362,21 @@ MGB_IMPL_OPR_GRAD(BatchedMatrixMul) {
if (wrt_idx == 0) {
// A * B = C, A' = C' * Bt
if (opr.param().transposeA) {
grad = BatchedMatrixMul::make(i1, og, {opr.param().transposeB, true});
grad = BatchedMatrixMul::make(
i1, og, {opr.param().transposeB, true}, opr.execution_policy());
} else {
grad = BatchedMatrixMul::make(og, i1, {false, !opr.param().transposeB});
grad = BatchedMatrixMul::make(
og, i1, {false, !opr.param().transposeB}, opr.execution_policy());
}
} else {
mgb_assert(wrt_idx == 1);
// A * B = C, B' = At * C'
if (opr.param().transposeB) {
grad = BatchedMatrixMul::make(og, i0, {true, opr.param().transposeA});
grad = BatchedMatrixMul::make(
og, i0, {true, opr.param().transposeA}, opr.execution_policy());
} else {
grad = BatchedMatrixMul::make(i0, og, {!opr.param().transposeA, false});
grad = BatchedMatrixMul::make(
i0, og, {!opr.param().transposeA, false}, opr.execution_policy());
}
}
return grad.node();
......
......@@ -59,7 +59,8 @@ size_t PoolingForward::get_workspace_size_bytes(
MGB_IMPL_OPR_GRAD(PoolingForward) {
mgb_assert(wrt_idx == 0);
SymbolVar grad = PoolingBackward::make(
opr.input(0), opr.output(0), out_grad[0], opr.param());
opr.input(0), opr.output(0), out_grad[0], opr.param(),
opr.execution_policy());
return grad.node();
}
#endif
......
......@@ -26,7 +26,7 @@ namespace opr {
/*!
* \brief matrix_mul(trans0(opr0), trans1(opr1))
*/
MGB_DEFINE_OPR_CLASS(
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
MatrixMul, intl::MegDNNOprWrapperFwd<megdnn::MatrixMul>,
public mixin::AlgoChooserHelper) // {
public:
......@@ -57,7 +57,7 @@ private:
/*!
* \brief batched matrix multiplication on 3D inputs
*/
MGB_DEFINE_OPR_CLASS(
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
BatchedMatrixMul, intl::MegDNNOprWrapperFwd<megdnn::BatchedMatrixMul>,
public mixin::AlgoChooserHelper) // {
public:
......
......@@ -18,7 +18,7 @@
namespace mgb {
namespace opr {
MGB_DEFINE_OPR_CLASS(
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
PoolingForward, intl::MegDNNOprWrapperFwd<megdnn::PoolingForward>,
public mixin::AlgoChooserHelper) // {
public:
......@@ -37,7 +37,7 @@ public:
};
using Pooling = PoolingForward;
MGB_DEFINE_OPR_CLASS(
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
PoolingBackward, intl::MegDNNOprWrapperBwd<megdnn::PoolingBackward>,
public mixin::AlgoChooserHelper) // {
public:
......
......@@ -51,7 +51,7 @@ public:
* Exception would be thrown if execution_policy() has been accessed,
* since it would influence cache and many other decisions.
*/
void set_execution_policy(const ExecutionPolicy& policy);
MGE_WIN_DECLSPEC_FUC void set_execution_policy(const ExecutionPolicy& policy);
/*!
* \brief register a hook to implement custom algo chooser
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册