Commit 787f187e authored by Megvii Engine Team, committed by huangxinda

fix(imperative/src): fix dot backward error

GitOrigin-RevId: 02ba44a0e6d8cd2ca863ae0058542b260e0b755d
Parent f35687ca
@@ -442,3 +442,18 @@ def test_removeAxis():
         grad(y, F.ones_like(y))
         np.testing.assert_equal(np.ones((3, 3, 1, 1), dtype=np.float32), x.grad.numpy())
+
+
+def test_dot():
+    x = np.random.rand(2, 2).astype("float32")
+    x = mge.Tensor(x)
+    u = F.ones((2,))
+    v = F.ones((2,))
+    grad = Grad().wrt(x, callback=save_to(x))
+
+    def f(x):
+        return F.dot(u, F.matmul(x, v))
+
+    y = f(x)
+    grad(y, F.ones_like(y))
+    np.testing.assert_equal(np.ones((2, 2), dtype=np.float32), x.grad.numpy())
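As a side note (not part of the commit): the expected gradient in test_dot follows from a one-line derivation. With y = uᵀ(Xv), the derivative with respect to X is the outer product u vᵀ; since u and v are all-ones vectors of length 2, the gradient is exactly the 2x2 all-ones matrix asserted at the end of the test. In LaTeX:

    y = u^\top (X v)
    \quad\Rightarrow\quad
    \frac{\partial y}{\partial X} = u v^\top
    = \begin{pmatrix} 1 \\ 1 \end{pmatrix}
      \begin{pmatrix} 1 & 1 \end{pmatrix}
    = \begin{pmatrix} 1 & 1 \\ 1 & 1 \end{pmatrix}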
@@ -33,7 +33,7 @@ DispatchMode decide_dispatch_mode(
         const SmallVector<LogicalTensorDesc>& inputs) {
     bool host_computable = true;
     for (auto&& inp : inputs) {
-        // FIXME(czh): remove value chech after proxy graph's
+        // FIXME(czh): remove value check after proxy graph's
         // apply_on_device_tensornd is supported and output Tensor
         // is made before add_task.
         // then if layout is valid, ptr->layout must be ready
@@ -50,9 +50,18 @@ void apply_on_device_tensornd(
         const SmallVector<DeviceTensorND>& inputs,
         SmallVector<DeviceTensorND>* outputs) {
     auto&& op_def = def.cast_final_safe<GetVarShape>();
-    mgb_assert(inputs.size() == 1, "GetVarShape take 1 input, got %lu", inputs.size());
-    auto&& inp = inputs[0];
-    auto&& shp = inp.layout();
+    TensorShape shp;
+    if (inputs.size() == 1) {
+        shp = inputs[0].layout();
+    } else {
+        TensorShapeArray src(inputs.size());
+        for (size_t i = 0; i < inputs.size(); ++i) {
+            src[i] = inputs[i].layout();
+        }
+        megdnn::Elemwise::deduce_shape(src, shp);
+    }
     mgb_assert(shp.ndim != 0, "input shape invalid");
     mgb_assert((*outputs)[0].comp_node() == CompNode::default_cpu(),
             "GetVarShape's apply_on_device_tensornd should receive default_cpu outputs.");
@@ -99,27 +108,36 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
         const OpDef& def,
         const SmallVector<LogicalTensorDesc>& inputs) {
     auto&& op_def = def.cast_final_safe<GetVarShape>();
-    mgb_assert(inputs.size() == 1, "GetVarShape take 1 input, got %lu", inputs.size());
     auto&& desc = inputs[0];
-    if (!desc.layout.ndim) {
+    TensorShape shp;
+    if (inputs.size() == 1) {
+        shp = desc.layout;
+    } else {
+        TensorShapeArray src(inputs.size());
+        for (size_t i = 0; i < inputs.size(); ++i) {
+            src[i] = inputs[i].layout;
+        }
+        megdnn::Elemwise::deduce_shape(src, shp);
+    }
+    if (!shp.ndim) {
         return {{{TensorLayout(dtype::Int32()), desc.comp_node}}, false};
     }
     DeviceTensorND value;
     if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS) {
-        value = DeviceTensorND(CompNode::default_cpu(), {desc.layout.ndim}, dtype::Int32());
+        value = DeviceTensorND(CompNode::default_cpu(), {shp.ndim}, dtype::Int32());
         auto* ptr = value.ptr<dt_int32>();
-        for (size_t i = 0; i < desc.layout.ndim; ++i) {
-            ptr[i] = desc.layout[i];
+        for (size_t i = 0; i < shp.ndim; ++i) {
+            ptr[i] = shp[i];
         }
     }else{
         int32_t axis = op_def.axis;
         if (axis < 0) {
-            axis += desc.layout.ndim;
+            axis += shp.ndim;
         }
-        mgb_assert(axis >= 0 && axis < (int32_t)desc.layout.ndim);
+        mgb_assert(axis >= 0 && axis < (int32_t)shp.ndim);
         value = DeviceTensorND(CompNode::default_cpu(), {1}, dtype::Int32());
         auto* ptr = value.ptr<dt_int32>();
-        ptr[0] = desc.layout[axis];
+        ptr[0] = shp[axis];
     }
     return {{{value.layout(), desc.comp_node, std::move(value)}}, true};
 }
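To summarize the control flow of the rewritten infer_output_attrs_fallible, here is a hedged pure-Python rendering (illustration only; axis=None plays the role of opr::GetVarShape::Param::INVALID_AXIS, and np.broadcast_shapes stands in for megdnn::Elemwise::deduce_shape):

    import numpy as np

    def get_var_shape_value(input_shapes, axis=None):
        # One input: use its shape directly; several inputs: deduce
        # the broadcast shape, mirroring the C++ branches above.
        if len(input_shapes) == 1:
            shp = input_shapes[0]
        else:
            shp = np.broadcast_shapes(*input_shapes)
        if axis is None:
            # INVALID_AXIS case: return the whole shape as an int32 vector.
            return np.array(shp, dtype=np.int32)
        if axis < 0:
            axis += len(shp)  # normalize a negative axis
        assert 0 <= axis < len(shp)
        return np.array([shp[axis]], dtype=np.int32)

    # e.g. get_var_shape_value([(2, 2), (2,)]) -> array([2, 2], dtype=int32)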