提交 7f280058 编写于 作者: M Megvii Engine Team

fix(imperative): fix the matrix mul error when symbolic trace

GitOrigin-RevId: 3754dc5a658718a4b00bd16d58e581116574adb7
上级 485f0a1f
......@@ -12,6 +12,7 @@ import megengine.optimizer as optim
from megengine import tensor
from megengine.autodiff import GradManager
from megengine.jit import trace
from megengine.optimizer import SGD
@contextlib.contextmanager
......@@ -138,3 +139,36 @@ def test_dump_bn_train_mode():
bn_train(data)
with pytest.raises(RuntimeError):
bn_train.dump("test.mge")
class ViTmode(M.Module):
def __init__(self, patch_size=16, in_chans=3, embed_dim=384):
super().__init__()
self.proj = M.Conv2d(
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
)
self.head = M.Linear(embed_dim, 1000)
def forward(self, x):
x = self.proj(x)
x = F.flatten(x, 2).transpose(0, 2, 1)
x = self.head(x)
return x
def test_ViTmode_trace_train():
model = ViTmode(embed_dim=384)
data = mge.random.normal(size=(1, 3, 224, 224))
optim = SGD(model.parameters(), lr=0.01)
gm = GradManager()
gm.attach(model.parameters())
@trace(symbolic=True, capture_as_const=True)
def train():
for i in range(2):
with gm:
loss = model(data)
gm.backward(loss)
optim.step().clear_grad()
train()
......@@ -22,62 +22,80 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
mgb_assert(inputs.size() == 2);
auto inp1 = SymbolVar{inputs[0]}, inp2 = SymbolVar{inputs[1]};
auto dim1 = matmul.dimA, dim2 = matmul.dimB;
mgb_assert(
dim1 >= 2 && dim2 >= 2,
"the dim of one input the matmul operator dim is less than 2.");
auto cn = inputs[0]->comp_node();
using IndexDesc = opr::Subtensor::IndexDesc;
OperatorNodeConfig config{matmul.make_name(), cn};
DTypeScalar vi{-1};
auto graph = inputs[0]->owner_graph();
SymbolVar shp1_head, shp1_tail, shp2_head, shp2_tail;
if (dim1 > 2) {
auto idx = opr::ImmutableTensor::make(*graph, vi, config);
auto shp1 = inp1.symshape();
IndexDesc head_desc(1);
head_desc[0].end = idx;
shp1_head = opr::Subtensor::make(shp1, head_desc);
auto batch = opr::Reduce::make(shp1_head, {Reduce::Mode::PRODUCT, 0});
IndexDesc tail_desc(1);
tail_desc[0].begin = idx;
shp1_tail = opr::Subtensor::make(shp1, tail_desc);
auto tshp = opr::Concat::make({batch, shp1_tail}, 0, cn);
inp1 = inp1.reshape(tshp);
}
if (dim2 > 2) {
auto idx = opr::ImmutableTensor::make(*graph, vi, config);
auto shp2 = inp2.symshape();
IndexDesc head_desc(1);
head_desc[0].end = idx;
shp2_head = opr::Subtensor::make(shp2, head_desc);
auto batch = opr::Reduce::make(shp2_head, {Reduce::Mode::PRODUCT, 0});
IndexDesc tail_desc(1);
tail_desc[0].begin = idx;
auto shp2_tail = opr::Subtensor::make(shp2, tail_desc);
auto tshp = opr::Concat::make({batch, shp2_tail}, 0, cn);
inp2 = inp2.reshape(tshp);
if (dim1 == 2 && dim2 == 2) {
return opr::MatrixMul::make(
inp1, inp2, matmul.param(), matmul.policy(), config);
}
auto result =
opr::MatrixMul::make(inp1, inp2, matmul.param(), matmul.policy(), config);
if (dim1 > 2) {
auto idx = opr::ImmutableTensor::make(*graph, vi, config);
auto result_shape = result.symshape();
IndexDesc tail_desc(1);
tail_desc[0].begin = idx;
auto shp_tail = opr::Subtensor::make(result_shape, tail_desc);
auto tshp = opr::Concat::make({shp1_head, shp_tail}, 0, cn);
result = result.reshape(tshp);
}
if (dim2 > 2) {
//! use batched matrix mul
SymbolVar shp_head, batch;
DTypeScalar vi{-2};
auto compress_shape = [&](SymbolVar inp) {
if (inp.shape().ndim > 3) {
auto idx = opr::ImmutableTensor::make(*graph, vi, config);
auto shp = inp.symshape();
IndexDesc head_desc(1);
head_desc[0].end = idx;
shp_head = opr::Subtensor::make(shp, head_desc);
batch = opr::Reduce::make(shp_head, {Reduce::Mode::PRODUCT, 0});
IndexDesc tail_desc(1);
tail_desc[0].begin = idx;
auto shp_tail = opr::Subtensor::make(shp, tail_desc);
auto tshp = opr::Concat::make({batch, shp_tail}, 0, cn);
return inp.reshape(tshp);
} else if (inp.shape().ndim == 3) {
auto idx = opr::ImmutableTensor::make(*graph, vi, config);
auto shp = inp.symshape();
IndexDesc head_desc(1);
head_desc[0].end = idx;
shp_head = opr::Subtensor::make(shp, head_desc);
batch = opr::Reduce::make(shp_head, {Reduce::Mode::PRODUCT, 0});
return inp;
} else {
return inp;
}
};
inp1 = compress_shape(inp1);
inp2 = compress_shape(inp2);
auto expand_shape = [&](SymbolVar inp) {
if (inp.shape().ndim < 3) {
auto shp = inp.symshape();
using Desc = opr::AxisAddRemove::AxisDesc;
std::vector<Desc> add_axis_param;
add_axis_param.push_back(Desc::make_add(0));
auto out = opr::AxisAddRemove::make(inp, add_axis_param);
auto target_shape = opr::Concat::make({batch, shp}, 0, cn);
return opr::Broadcast::make(out, target_shape);
} else {
return inp;
}
};
inp1 = expand_shape(inp1);
inp2 = expand_shape(inp2);
auto result = opr::BatchedMatrixMul::make(
inp1, inp2, matmul.param(), matmul.policy(), config);
size_t max_dim = std::max(dim1, dim2);
if (max_dim > 3) {
auto idx = opr::ImmutableTensor::make(*graph, vi, config);
auto result_shape = result.symshape();
auto res_shp = result.symshape();
IndexDesc tail_desc(1);
tail_desc[0].begin = idx;
auto shp_tail = opr::Subtensor::make(result_shape, tail_desc);
auto tshp = opr::Concat::make({shp2_head, shp_tail}, 0, cn);
auto tail_shape = opr::Subtensor::make(res_shp, tail_desc);
auto tshp = opr::Concat::make({shp_head, tail_shape}, 0, cn);
result = result.reshape(tshp);
}
return result;
}
......
......@@ -150,6 +150,31 @@ void OprChecker::run(std::vector<InputSpec> inp_keys, std::set<size_t> bypass) {
}
}
VarNodeArray OprChecker::run_apply_on_var_node(std::vector<InputSpec> inp_keys) {
HostTensorGenerator<> gen;
size_t nr_inps = inp_keys.size();
SmallVector<HostTensorND> host_inp(nr_inps);
VarNodeArray sym_inp(nr_inps);
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
for (size_t i = 0; i < nr_inps; ++i) {
// TODO: remove std::visit for support osx 10.12
host_inp[i] = std::visit(
[&gen](auto&& arg) -> HostTensorND {
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<TensorShape, T>) {
return *gen(arg);
} else {
static_assert(std::is_same_v<HostTensorND, T>);
return arg;
}
},
inp_keys[i]);
sym_inp[i] = opr::SharedDeviceTensor::make(*graph, host_inp[i]).node();
}
return OpDef::apply_on_var_node(*m_op, sym_inp);
}
TEST(TestHelper, PyModule) {
py::module m = PyEnv::get();
py::print(m);
......
......@@ -14,6 +14,9 @@ public:
OprChecker(std::shared_ptr<OpDef> opdef);
void run(std::vector<InputSpec> inp_shapes, std::set<size_t> bypass = {});
//! test the interface of apply_on_var_node
VarNodeArray run_apply_on_var_node(std::vector<InputSpec> inp_shapes);
private:
std::shared_ptr<OpDef> m_op;
};
......
#include "./helper.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/imperative/blob_manager.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/basic_arith_wrapper.h"
#include "megbrain/opr/blas.h"
#include "megbrain/opr/dnn/batch_norm.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/tensor_manip.h"
......@@ -164,4 +166,58 @@ TEST(TestImperative, Defragment) {
}
#endif // MGB_CUDA && MGB_ENABLE_EXCEPTION
TEST(TestImperative, MatrixMulApplyOnVarNode) {
using Param = opr::MatrixMul::Param;
Param param;
std::vector<std::pair<TensorShape, TensorShape>> shapes;
std::vector<TensorShape> target_shapes;
std::vector<Param> params;
//! testcase 0
params.push_back(param);
shapes.push_back({TensorShape{10, 5}, TensorShape{5, 10}});
target_shapes.push_back(TensorShape{10, 10});
//! testcase 1
params.push_back(param);
shapes.push_back({TensorShape{3, 10, 5}, TensorShape{5, 10}});
target_shapes.push_back(TensorShape{3, 10, 10});
//! testcase 2
param.transposeA = true;
param.transposeB = false;
params.push_back(param);
shapes.push_back({TensorShape{3, 7, 6}, TensorShape{7, 10}});
target_shapes.push_back(TensorShape{3, 6, 10});
//! testcase 3
param.transposeA = true;
param.transposeB = false;
params.push_back(param);
shapes.push_back({TensorShape{2, 3, 7, 6}, TensorShape{7, 10}});
target_shapes.push_back(TensorShape{2, 3, 6, 10});
//! testcase 4
param.transposeA = false;
param.transposeB = true;
params.push_back(param);
shapes.push_back({TensorShape{2, 3, 7, 6}, TensorShape{2, 3, 8, 6}});
target_shapes.push_back(TensorShape{2, 3, 7, 8});
//! testcase 5
param.transposeA = false;
param.transposeB = true;
params.push_back(param);
shapes.push_back({TensorShape{2, 3, 7, 6}, TensorShape{8, 6}});
target_shapes.push_back(TensorShape{2, 3, 7, 8});
for (size_t i = 0; i < params.size(); i++) {
auto& shape = shapes[i];
auto op = MatrixMul::make(
params[i], ::megdnn::param::ExecutionPolicy{}, shape.first.ndim,
shape.second.ndim);
auto result = OprChecker(op).run_apply_on_var_node({shape.first, shape.second});
ASSERT_GT(result.size(), 0);
ASSERT_EQ(target_shapes[i].ndim, result[0]->shape().ndim);
for (size_t id = 0; id < target_shapes[i].ndim; id++) {
ASSERT_EQ(target_shapes[i][id], result[0]->shape()[id]);
}
}
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册