提交 0da0bdcf 编写于 作者: B buxue

Fix bug structure output when there is depend whose first input is constant in outputs

上级 04eab166
......@@ -725,23 +725,15 @@ py::object ExecutorPy::Run(const py::tuple& args, const py::object& phase) {
return BaseRefToPyData(value);
}
py::object StructureOutput(const AbstractBasePtr& output, const py::tuple& data, size_t* count) {
MS_EXCEPTION_IF_NULL(output);
py::object ExtractGeneralCnodeRet(const AbstractBasePtr& cnode_data, const py::tuple& data, size_t* count) {
MS_EXCEPTION_IF_NULL(cnode_data);
if (*count >= data.size()) {
MS_LOG(EXCEPTION) << "The number of elements in the outputs : " << data.size()
<< " less than the number of elements required. ";
}
if (!output->isa<AbstractTuple>()) {
ValuePtr value = output->BuildValue();
if (value != kAnyValue) {
return ValuePtrToPyData(value);
}
if (!output->isa<AbstractTensor>()) {
MS_LOG(EXCEPTION) << "Output can only be tensor except for constants, but got "
<< output->BuildValue()->ToString() << ".";
}
if (*count >= data.size()) {
MS_LOG(EXCEPTION) << "The number of elements in the outputs : " << data.size()
<< " less than the number of elements required. ";
}
auto shape = output->BuildShape();
if (cnode_data->isa<AbstractTensor>()) {
BaseShapePtr shape = cnode_data->BuildShape();
auto shape_act = shape->cast<abstract::ShapePtr>()->shape();
Tensor tensor_exp = py::cast<Tensor>(data[*count]);
if (shape_act != tensor_exp.shape()) {
......@@ -751,16 +743,58 @@ py::object StructureOutput(const AbstractBasePtr& output, const py::tuple& data,
return data[(*count)++];
}
auto tuple_output = output->cast<AbstractTuplePtr>();
AbstractBasePtrList elements = tuple_output->elements();
size_t size = elements.size();
if (!cnode_data->isa<AbstractTuple>()) {
MS_LOG(EXCEPTION) << "The output of operator in the final anf graph could "
<< "only be a tensor or a tuple of tensor, but got " << cnode_data->BuildValue()->ToString()
<< ".";
}
auto data_tp = cnode_data->cast<AbstractTuplePtr>();
auto elements = data_tp->elements();
size_t size = data_tp->size();
py::tuple tp = py::tuple(size);
for (size_t i = 0; i < size; i++) {
tp[i] = StructureOutput(elements[i], data, count);
tp[i] = ExtractGeneralCnodeRet(elements[i], data, count);
}
return std::move(tp);
}
py::object StructureOutput(const AnfNodePtr& output_node, const py::tuple& data, size_t* count) {
MS_EXCEPTION_IF_NULL(output_node);
if (output_node->isa<ValueNode>()) {
return ValuePtrToPyData(GetValueNode(output_node));
}
if (*count >= data.size()) {
MS_LOG(EXCEPTION) << "The number of elements in the outputs : " << data.size()
<< " less than the number of elements required. ";
}
if (output_node->isa<Parameter>()) {
return data[(*count)++];
}
auto output_c = output_node->cast<CNodePtr>();
if (output_c == nullptr) {
MS_LOG(EXCEPTION) << "The final anf graph could only have constant, parameter, and operator, but got "
<< output_node->ToString();
}
if (output_c->IsApply(prim::kPrimMakeTuple)) {
auto input_list = output_c->inputs();
size_t size = input_list.size();
py::tuple tp = py::tuple(size - 1);
for (size_t i = 1; i < size; i++) {
tp[i - 1] = StructureOutput(input_list[i], data, count);
}
return std::move(tp);
}
if (output_c->IsApply(prim::kPrimDepend)) {
return StructureOutput(output_c->input(1), data, count);
}
return ExtractGeneralCnodeRet(output_c->abstract(), data, count);
}
std::shared_ptr<py::object> DoExecGraph(const FuncGraphPtr& graph, const std::vector<MeTensorPtr>& inputs,
const std::string& phase) {
std::vector<GeTensorPtr> ge_tensors = TransformUtil::ConvertInputTensors(inputs, kOpFormat_NCHW);
......@@ -806,11 +840,10 @@ std::shared_ptr<py::object> DoExecGraph(const FuncGraphPtr& graph, const std::ve
std::shared_ptr<py::object> ret = nullptr;
#ifdef ENABLE_GE
AnfNodePtr root = graph->get_return();
MS_EXCEPTION_IF_NULL(root);
AbstractBasePtr output = root->abstract();
AnfNodePtr output_node = graph->get_return()->input(1);
MS_EXCEPTION_IF_NULL(output_node);
size_t count = 0;
py::object oj = StructureOutput(output, outputs, &count);
py::object oj = StructureOutput(output_node, outputs, &count);
ret = std::make_shared<py::object>(oj);
#else
if (outputs.size() == 1) {
......
......@@ -236,7 +236,7 @@ def test_soft():
def __init__(self):
super(SoftmaxCrossEntropyWithLogitsNet, self).__init__()
self.soft = P.SoftmaxCrossEntropyWithLogits()
self.value = (Tensor(np.zeros((2,)).astype(np.float32)), Tensor(np.ones((2,)).astype(np.float32)))
self.value = (Tensor(np.zeros((2, 2)).astype(np.float32)), Tensor(np.ones((2, 2)).astype(np.float32)))
def construct(self, x, y, z):
xx = x + y
......@@ -246,8 +246,30 @@ def test_soft():
ret = (ret, self.value)
return ret
input1 = Tensor(np.zeros((2,)).astype(np.float32))
input2 = Tensor(np.ones((2,)).astype(np.float32))
input3 = Tensor((np.ones((2,)) + np.ones((2,))).astype(np.float32))
input1 = Tensor(np.zeros((2, 2)).astype(np.float32))
input2 = Tensor(np.ones((2, 2)).astype(np.float32))
input3 = Tensor((np.ones((2, 2)) + np.ones((2, 2))).astype(np.float32))
net = SoftmaxCrossEntropyWithLogitsNet()
print(net(input1, input2, input3))
net(input1, input2, input3)
def test_const_depend():
class ConstDepend(Cell):
def __init__(self):
super(ConstDepend, self).__init__()
self.value = (Tensor(np.zeros((2, 3)).astype(np.float32)), Tensor(np.ones((2, 3)).astype(np.float32)))
self.soft = P.SoftmaxCrossEntropyWithLogits()
self.depend = depend
def construct(self, x, y, z):
ret = x + y
ret = ret * z
ret = self.depend(self.value, ret)
ret = (ret, self.soft(x, y))
return ret
input1 = Tensor(np.zeros((2, 2)).astype(np.float32))
input2 = Tensor(np.ones((2, 2)).astype(np.float32))
input3 = Tensor((np.ones((2, 2)) + np.ones((2, 2))).astype(np.float32))
net = ConstDepend()
net(input1, input2, input3)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册