Commit 7f280058 authored by Megvii Engine Team

fix(imperative): fix the matrix mul error when symbolic trace

GitOrigin-RevId: 3754dc5a658718a4b00bd16d58e581116574adb7
Parent 485f0a1f
@@ -12,6 +12,7 @@ import megengine.optimizer as optim
from megengine import tensor
from megengine.autodiff import GradManager
from megengine.jit import trace
from megengine.optimizer import SGD


@contextlib.contextmanager
@@ -138,3 +139,36 @@ def test_dump_bn_train_mode():
    bn_train(data)
    with pytest.raises(RuntimeError):
        bn_train.dump("test.mge")


class ViTmode(M.Module):
    def __init__(self, patch_size=16, in_chans=3, embed_dim=384):
        super().__init__()
        self.proj = M.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )
        self.head = M.Linear(embed_dim, 1000)

    def forward(self, x):
        x = self.proj(x)
        x = F.flatten(x, 2).transpose(0, 2, 1)
        x = self.head(x)
        return x


def test_ViTmode_trace_train():
    model = ViTmode(embed_dim=384)
    data = mge.random.normal(size=(1, 3, 224, 224))
    optim = SGD(model.parameters(), lr=0.01)
    gm = GradManager()
    gm.attach(model.parameters())

    @trace(symbolic=True, capture_as_const=True)
    def train():
        for i in range(2):
            with gm:
                loss = model(data)
                gm.backward(loss)
                optim.step().clear_grad()

    train()
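For context, head is a Linear layer applied to a 3-D activation: with a (1, 3, 224, 224) input and 16x16 patches, proj produces (1, 384, 14, 14), which flatten/transpose turn into (1, 196, 384), so self.head(x) performs a 3-D x 2-D matrix multiplication under trace(symbolic=True). A minimal standalone sketch of that same pattern (illustrative only, not part of the commit; it assumes the public F.matmul API) is:

import megengine as mge
import megengine.functional as F
from megengine.jit import trace


@trace(symbolic=True)
def matmul_3d_2d(x, w):
    # batched lhs (1, 196, 384) times plain 2-D rhs (384, 1000)
    return F.matmul(x, w)


out = matmul_3d_2d(
    mge.random.normal(size=(1, 196, 384)),
    mge.random.normal(size=(384, 1000)),
)
print(out.shape)  # expected: (1, 196, 1000)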
@@ -22,62 +22,80 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
     mgb_assert(inputs.size() == 2);
     auto inp1 = SymbolVar{inputs[0]}, inp2 = SymbolVar{inputs[1]};
     auto dim1 = matmul.dimA, dim2 = matmul.dimB;
+    mgb_assert(
+            dim1 >= 2 && dim2 >= 2,
+            "the dim of one input of the matmul operator is less than 2.");
     auto cn = inputs[0]->comp_node();
     using IndexDesc = opr::Subtensor::IndexDesc;
     OperatorNodeConfig config{matmul.make_name(), cn};
-    DTypeScalar vi{-1};
     auto graph = inputs[0]->owner_graph();
-    SymbolVar shp1_head, shp1_tail, shp2_head, shp2_tail;
-    if (dim1 > 2) {
-        auto idx = opr::ImmutableTensor::make(*graph, vi, config);
-        auto shp1 = inp1.symshape();
-        IndexDesc head_desc(1);
-        head_desc[0].end = idx;
-        shp1_head = opr::Subtensor::make(shp1, head_desc);
-        auto batch = opr::Reduce::make(shp1_head, {Reduce::Mode::PRODUCT, 0});
-        IndexDesc tail_desc(1);
-        tail_desc[0].begin = idx;
-        shp1_tail = opr::Subtensor::make(shp1, tail_desc);
-        auto tshp = opr::Concat::make({batch, shp1_tail}, 0, cn);
-        inp1 = inp1.reshape(tshp);
-    }
-    if (dim2 > 2) {
-        auto idx = opr::ImmutableTensor::make(*graph, vi, config);
-        auto shp2 = inp2.symshape();
-        IndexDesc head_desc(1);
-        head_desc[0].end = idx;
-        shp2_head = opr::Subtensor::make(shp2, head_desc);
-        auto batch = opr::Reduce::make(shp2_head, {Reduce::Mode::PRODUCT, 0});
-        IndexDesc tail_desc(1);
-        tail_desc[0].begin = idx;
-        auto shp2_tail = opr::Subtensor::make(shp2, tail_desc);
-        auto tshp = opr::Concat::make({batch, shp2_tail}, 0, cn);
-        inp2 = inp2.reshape(tshp);
-    }
-    auto result =
-            opr::MatrixMul::make(inp1, inp2, matmul.param(), matmul.policy(), config);
-    if (dim1 > 2) {
-        auto idx = opr::ImmutableTensor::make(*graph, vi, config);
-        auto result_shape = result.symshape();
-        IndexDesc tail_desc(1);
-        tail_desc[0].begin = idx;
-        auto shp_tail = opr::Subtensor::make(result_shape, tail_desc);
-        auto tshp = opr::Concat::make({shp1_head, shp_tail}, 0, cn);
-        result = result.reshape(tshp);
-    }
-    if (dim2 > 2) {
-        auto idx = opr::ImmutableTensor::make(*graph, vi, config);
-        auto result_shape = result.symshape();
-        IndexDesc tail_desc(1);
-        tail_desc[0].begin = idx;
-        auto shp_tail = opr::Subtensor::make(result_shape, tail_desc);
-        auto tshp = opr::Concat::make({shp2_head, shp_tail}, 0, cn);
-        result = result.reshape(tshp);
-    }
+    if (dim1 == 2 && dim2 == 2) {
+        return opr::MatrixMul::make(
+                inp1, inp2, matmul.param(), matmul.policy(), config);
+    }
+    //! use batched matrix mul
+    SymbolVar shp_head, batch;
+    DTypeScalar vi{-2};
+    auto compress_shape = [&](SymbolVar inp) {
+        if (inp.shape().ndim > 3) {
+            auto idx = opr::ImmutableTensor::make(*graph, vi, config);
+            auto shp = inp.symshape();
+            IndexDesc head_desc(1);
+            head_desc[0].end = idx;
+            shp_head = opr::Subtensor::make(shp, head_desc);
+            batch = opr::Reduce::make(shp_head, {Reduce::Mode::PRODUCT, 0});
+            IndexDesc tail_desc(1);
+            tail_desc[0].begin = idx;
+            auto shp_tail = opr::Subtensor::make(shp, tail_desc);
+            auto tshp = opr::Concat::make({batch, shp_tail}, 0, cn);
+            return inp.reshape(tshp);
+        } else if (inp.shape().ndim == 3) {
+            auto idx = opr::ImmutableTensor::make(*graph, vi, config);
+            auto shp = inp.symshape();
+            IndexDesc head_desc(1);
+            head_desc[0].end = idx;
+            shp_head = opr::Subtensor::make(shp, head_desc);
+            batch = opr::Reduce::make(shp_head, {Reduce::Mode::PRODUCT, 0});
+            return inp;
+        } else {
+            return inp;
+        }
+    };
+    inp1 = compress_shape(inp1);
+    inp2 = compress_shape(inp2);
+    auto expand_shape = [&](SymbolVar inp) {
+        if (inp.shape().ndim < 3) {
+            auto shp = inp.symshape();
+            using Desc = opr::AxisAddRemove::AxisDesc;
+            std::vector<Desc> add_axis_param;
+            add_axis_param.push_back(Desc::make_add(0));
+            auto out = opr::AxisAddRemove::make(inp, add_axis_param);
+            auto target_shape = opr::Concat::make({batch, shp}, 0, cn);
+            return opr::Broadcast::make(out, target_shape);
+        } else {
+            return inp;
+        }
+    };
+    inp1 = expand_shape(inp1);
+    inp2 = expand_shape(inp2);
+    auto result = opr::BatchedMatrixMul::make(
+            inp1, inp2, matmul.param(), matmul.policy(), config);
+    size_t max_dim = std::max(dim1, dim2);
+    if (max_dim > 3) {
+        auto idx = opr::ImmutableTensor::make(*graph, vi, config);
+        auto res_shp = result.symshape();
+        IndexDesc tail_desc(1);
+        tail_desc[0].begin = idx;
+        auto tail_shape = opr::Subtensor::make(res_shp, tail_desc);
+        auto tshp = opr::Concat::make({shp_head, tail_shape}, 0, cn);
+        result = result.reshape(tshp);
+    }
     return result;
 }
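In short, the rewritten path flattens the leading batch dimensions of any operand with more than three dims into a single batch axis (compress_shape), broadcasts a 2-D operand up to 3-D (expand_shape), runs BatchedMatrixMul, and then restores the original leading dimensions on the result. A rough NumPy sketch of that shape bookkeeping (illustrative only; it ignores the transposeA/transposeB flags and the dtype/execution-policy handling) is:

import numpy as np

def batched_matmul_shape_sketch(a, b):
    # assumes at least one operand has more than 2 dims
    batch_shape = a.shape[:-2] if a.ndim >= b.ndim else b.shape[:-2]

    def compress(x):
        # collapse all leading dims into a single batch axis -> (B, m, n)
        return x.reshape((-1,) + x.shape[-2:]) if x.ndim > 3 else x

    a, b = compress(a), compress(b)
    batch = a.shape[0] if a.ndim == 3 else b.shape[0]

    def expand(x):
        # broadcast a 2-D operand up to (B, m, n) to match the batched one
        return np.broadcast_to(x, (batch,) + x.shape) if x.ndim == 2 else x

    out = expand(a) @ expand(b)  # batched (B, m, k) @ (B, k, n)
    # restore the original leading (batch) dimensions on the result
    return out.reshape(batch_shape + out.shape[-2:])

# e.g. (2, 3, 7, 6) @ (6, 8) -> (2, 3, 7, 8)
print(batched_matmul_shape_sketch(np.ones((2, 3, 7, 6)), np.ones((6, 8))).shape)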
...
@@ -150,6 +150,31 @@ void OprChecker::run(std::vector<InputSpec> inp_keys, std::set<size_t> bypass) {
    }
}
VarNodeArray OprChecker::run_apply_on_var_node(std::vector<InputSpec> inp_keys) {
    HostTensorGenerator<> gen;
    size_t nr_inps = inp_keys.size();
    SmallVector<HostTensorND> host_inp(nr_inps);
    VarNodeArray sym_inp(nr_inps);
    auto graph = ComputingGraph::make();
    graph->options().graph_opt_level = 0;
    for (size_t i = 0; i < nr_inps; ++i) {
        // TODO: remove std::visit to support osx 10.12
        host_inp[i] = std::visit(
                [&gen](auto&& arg) -> HostTensorND {
                    using T = std::decay_t<decltype(arg)>;
                    if constexpr (std::is_same_v<TensorShape, T>) {
                        return *gen(arg);
                    } else {
                        static_assert(std::is_same_v<HostTensorND, T>);
                        return arg;
                    }
                },
                inp_keys[i]);
        sym_inp[i] = opr::SharedDeviceTensor::make(*graph, host_inp[i]).node();
    }
    return OpDef::apply_on_var_node(*m_op, sym_inp);
}
TEST(TestHelper, PyModule) {
    py::module m = PyEnv::get();
    py::print(m);
...
@@ -14,6 +14,9 @@ public:
    OprChecker(std::shared_ptr<OpDef> opdef);
    void run(std::vector<InputSpec> inp_shapes, std::set<size_t> bypass = {});

    //! test the interface of apply_on_var_node
    VarNodeArray run_apply_on_var_node(std::vector<InputSpec> inp_shapes);

private:
    std::shared_ptr<OpDef> m_op;
};
...
#include "./helper.h" #include "./helper.h"
#include "megbrain/comp_node_env.h" #include "megbrain/comp_node_env.h"
#include "megbrain/imperative/blob_manager.h" #include "megbrain/imperative/blob_manager.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/opr_attr.h" #include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/opr/basic_arith.h" #include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/basic_arith_wrapper.h" #include "megbrain/opr/basic_arith_wrapper.h"
#include "megbrain/opr/blas.h"
#include "megbrain/opr/dnn/batch_norm.h" #include "megbrain/opr/dnn/batch_norm.h"
#include "megbrain/opr/dnn/convolution.h" #include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/tensor_manip.h" #include "megbrain/opr/tensor_manip.h"
...@@ -164,4 +166,58 @@ TEST(TestImperative, Defragment) { ...@@ -164,4 +166,58 @@ TEST(TestImperative, Defragment) {
} }
#endif // MGB_CUDA && MGB_ENABLE_EXCEPTION #endif // MGB_CUDA && MGB_ENABLE_EXCEPTION

TEST(TestImperative, MatrixMulApplyOnVarNode) {
    using Param = opr::MatrixMul::Param;
    Param param;
    std::vector<std::pair<TensorShape, TensorShape>> shapes;
    std::vector<TensorShape> target_shapes;
    std::vector<Param> params;
    //! testcase 0
    params.push_back(param);
    shapes.push_back({TensorShape{10, 5}, TensorShape{5, 10}});
    target_shapes.push_back(TensorShape{10, 10});
    //! testcase 1
    params.push_back(param);
    shapes.push_back({TensorShape{3, 10, 5}, TensorShape{5, 10}});
    target_shapes.push_back(TensorShape{3, 10, 10});
    //! testcase 2
    param.transposeA = true;
    param.transposeB = false;
    params.push_back(param);
    shapes.push_back({TensorShape{3, 7, 6}, TensorShape{7, 10}});
    target_shapes.push_back(TensorShape{3, 6, 10});
    //! testcase 3
    param.transposeA = true;
    param.transposeB = false;
    params.push_back(param);
    shapes.push_back({TensorShape{2, 3, 7, 6}, TensorShape{7, 10}});
    target_shapes.push_back(TensorShape{2, 3, 6, 10});
    //! testcase 4
    param.transposeA = false;
    param.transposeB = true;
    params.push_back(param);
    shapes.push_back({TensorShape{2, 3, 7, 6}, TensorShape{2, 3, 8, 6}});
    target_shapes.push_back(TensorShape{2, 3, 7, 8});
    //! testcase 5
    param.transposeA = false;
    param.transposeB = true;
    params.push_back(param);
    shapes.push_back({TensorShape{2, 3, 7, 6}, TensorShape{8, 6}});
    target_shapes.push_back(TensorShape{2, 3, 7, 8});
    for (size_t i = 0; i < params.size(); i++) {
        auto& shape = shapes[i];
        auto op = MatrixMul::make(
                params[i], ::megdnn::param::ExecutionPolicy{}, shape.first.ndim,
                shape.second.ndim);
        auto result = OprChecker(op).run_apply_on_var_node({shape.first, shape.second});
        ASSERT_GT(result.size(), 0);
        ASSERT_EQ(target_shapes[i].ndim, result[0]->shape().ndim);
        for (size_t id = 0; id < target_shapes[i].ndim; id++) {
            ASSERT_EQ(target_shapes[i][id], result[0]->shape()[id]);
        }
    }
}
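As a rough cross-check of the transposed case, the expected shape of testcase 4 agrees with plain NumPy semantics when B's last two axes are swapped (illustrative only, not part of the commit):

import numpy as np

a = np.ones((2, 3, 7, 6))
b = np.ones((2, 3, 8, 6))
# transposeB=True: multiply by B with its last two axes swapped
out = a @ b.swapaxes(-1, -2)
print(out.shape)  # (2, 3, 7, 8), i.e. target_shapes[4]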
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}