提交 f16e4311 编写于 作者: M Megvii Engine Team

fix(imperative): fix the matrix mul error when symbolic trace

GitOrigin-RevId: ed2424d101319b0ec7d8e140995e5168e2be8b8f
上级 67327685
......@@ -12,6 +12,7 @@ import megengine.optimizer as optim
from megengine import tensor
from megengine.autodiff import GradManager
from megengine.jit import trace
from megengine.optimizer import SGD
@contextlib.contextmanager
......@@ -138,3 +139,50 @@ def test_dump_bn_train_mode():
bn_train(data)
with pytest.raises(RuntimeError):
bn_train.dump("test.mge")
class ViT(M.Module):
def __init__(self, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
self.proj = M.Conv2d(
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
)
self.extra = M.Linear(embed_dim, embed_dim)
self.head = M.Linear(embed_dim, 1)
def forward(self, x):
x = self.proj(x)
x = F.flatten(x, 2).transpose(0, 2, 1)
x = self.extra(x)
x = x.mean(axis=1)
x = self.head(x)
x = x.sum()
return x
def test_ViTmode_trace_train():
model = ViT(embed_dim=384)
data = mge.random.normal(size=(1, 3, 224, 224))
optim = SGD(model.parameters(), lr=0.01)
gm = GradManager()
gm.attach(model.parameters())
@trace(symbolic=True)
def train(d):
with gm:
loss = model(d)
gm.backward(loss)
optim.step().clear_grad()
return loss
@trace(symbolic=True)
def val(d):
loss = model(d)
return loss
for i in range(3):
print(f"iter: {i}")
t = train(data)
r = val(data)
......@@ -59,26 +59,26 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
}
auto result =
opr::MatrixMul::make(inp1, inp2, matmul.param(), matmul.policy(), config);
if (dim1 > 2) {
auto idx = opr::ImmutableTensor::make(*graph, vi, config);
auto result_shape = result.symshape();
IndexDesc tail_desc(1);
tail_desc[0].begin = idx;
auto shp_tail = opr::Subtensor::make(result_shape, tail_desc);
auto tshp = opr::Concat::make({shp1_head, shp_tail}, 0, cn);
result = result.reshape(tshp);
}
if (dim2 > 2) {
//! dim1 > 2 && dim2 > 2: (n, c, o) * (n, c, hw) ===>(nxc, o) * (nxc, hw)=(hw, o)
//! dim1 == 2 && dim2 == 2: (c, o) * (c, hw) --> (hw, o)
if ((dim1 > 2 && dim2 > 2) || (dim1 == dim2 && dim1 == 2)) {
return result;
} else {
auto idx = opr::ImmutableTensor::make(*graph, vi, config);
auto result_shape = result.symshape();
IndexDesc tail_desc(1);
tail_desc[0].begin = idx;
auto shp_tail = opr::Subtensor::make(result_shape, tail_desc);
auto tshp = opr::Concat::make({shp2_head, shp_tail}, 0, cn);
result = result.reshape(tshp);
if (dim1 > 2) {
auto tshp = opr::Concat::make({shp1_head, shp_tail}, 0, cn);
result = result.reshape(tshp);
} else {
mgb_assert(dim2 > 2);
auto tshp = opr::Concat::make({shp2_head, shp_tail}, 0, cn);
result = result.reshape(tshp);
}
return result;
}
return result;
}
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
......
......@@ -150,6 +150,31 @@ void OprChecker::run(std::vector<InputSpec> inp_keys, std::set<size_t> bypass) {
}
}
VarNodeArray OprChecker::run_apply_on_var_node(std::vector<InputSpec> inp_keys) {
HostTensorGenerator<> gen;
size_t nr_inps = inp_keys.size();
SmallVector<HostTensorND> host_inp(nr_inps);
VarNodeArray sym_inp(nr_inps);
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
for (size_t i = 0; i < nr_inps; ++i) {
// TODO: remove std::visit for support osx 10.12
host_inp[i] = std::visit(
[&gen](auto&& arg) -> HostTensorND {
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<TensorShape, T>) {
return *gen(arg);
} else {
static_assert(std::is_same_v<HostTensorND, T>);
return arg;
}
},
inp_keys[i]);
sym_inp[i] = opr::SharedDeviceTensor::make(*graph, host_inp[i]).node();
}
return OpDef::apply_on_var_node(*m_op, sym_inp);
}
TEST(TestHelper, PyModule) {
py::module m = PyEnv::get();
py::print(m);
......
......@@ -21,6 +21,9 @@ public:
OprChecker(std::shared_ptr<OpDef> opdef);
void run(std::vector<InputSpec> inp_shapes, std::set<size_t> bypass = {});
//! test the interface of apply_on_var_node
VarNodeArray run_apply_on_var_node(std::vector<InputSpec> inp_shapes);
private:
std::shared_ptr<OpDef> m_op;
};
......
#include "./helper.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/imperative/blob_manager.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/basic_arith_wrapper.h"
#include "megbrain/opr/blas.h"
#include "megbrain/opr/dnn/batch_norm.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/tensor_manip.h"
......@@ -164,4 +166,45 @@ TEST(TestImperative, Defragment) {
}
#endif // MGB_CUDA && MGB_ENABLE_EXCEPTION
TEST(TestImperative, MatrixMulApplyOnVarNode) {
using Param = opr::MatrixMul::Param;
Param param;
std::vector<std::pair<TensorShape, TensorShape>> shapes;
std::vector<TensorShape> target_shapes;
std::vector<Param> params;
//! testcase 0
params.push_back(param);
shapes.push_back({TensorShape{10, 5}, TensorShape{5, 10}});
target_shapes.push_back(TensorShape{10, 10});
params.push_back(param);
shapes.push_back({TensorShape{3, 10, 5}, TensorShape{5, 10}});
target_shapes.push_back(TensorShape{3, 10, 10});
//! testcase 2
param.transposeA = true;
param.transposeB = false;
params.push_back(param);
shapes.push_back({TensorShape{3, 7, 6}, TensorShape{3, 7, 10}});
target_shapes.push_back(TensorShape{6, 10});
//! testcase 3
param.transposeA = false;
param.transposeB = true;
params.push_back(param);
shapes.push_back({TensorShape{2, 3, 7, 6}, TensorShape{10, 6}});
target_shapes.push_back(TensorShape{2, 3, 7, 10});
for (size_t i = 0; i < params.size(); i++) {
auto& shape = shapes[i];
auto op = MatrixMul::make(
params[i], ::megdnn::param::ExecutionPolicy{}, shape.first.ndim,
shape.second.ndim);
auto result = OprChecker(op).run_apply_on_var_node({shape.first, shape.second});
ASSERT_GT(result.size(), 0);
ASSERT_EQ(target_shapes[i].ndim, result[0]->shape().ndim);
for (size_t id = 0; id < target_shapes[i].ndim; id++) {
ASSERT_EQ(target_shapes[i][id], result[0]->shape()[id]);
}
}
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册