提交 da7d605e 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!683 WIP: specialize hyper map parameter

Merge pull request !683 from xychow/bypass-renorm-and-specialize-hypermap-parameter
......@@ -42,6 +42,7 @@ using CNodeIndexCounterMap = OrderedMap<CNodeIndexPairPtr, int, CNodeIndexHasher
const char FUNC_GRAPH_FLAG_IGNORE_VALUES[] = "ignore_values";
const char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline";
const char FUNC_GRAPH_FLAG_CORE[] = "core";
const char FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER[] = "spec_param";
// ANF transform class
// either a primitive or a func_graph
......
......@@ -23,6 +23,7 @@
#include <sstream>
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "pipeline/static_analysis/abstract_value.h"
#include "pipeline/static_analysis/abstract_function.h"
#include "pipeline/static_analysis/dshape.h"
......@@ -334,6 +335,7 @@ ArgsPairList HyperMap::Harmonize(const FuncGraphPtr &func_graph, const ArgsPairL
FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_spec_list) {
FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>();
ptrGraph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
ptrGraph->set_flags(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
ptrGraph->debug_info()->set_name("hyper_map");
AnfNodePtr ptrFnArg = nullptr;
......
......@@ -278,10 +278,12 @@ AnfNodePtr ConvertValueListNodeToValueTupleNode(const ValueNodePtr &node) {
// Convert class to Tuple
// Convert getattr to getitem
// Convert make_record to make_tuple
void SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
MS_EXCEPTION_IF_NULL(manager);
manager->AddFuncGraph(root);
bool changed = false;
// Since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var
AnfNodeSet all_node = manager->all_nodes();
for (auto &node : all_node) {
......@@ -316,7 +318,9 @@ void SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr
if (new_node != nullptr) {
new_node->set_abstract(node->abstract());
MS_LOG(DEBUG) << "Replace node: " << node->DebugString() << " with new_node: " << new_node->DebugString();
(void)manager->Replace(node, new_node);
changed = true;
}
}
......@@ -324,6 +328,7 @@ void SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr
auto ret = Reabs(node->abstract());
node->set_abstract(ret);
}
return changed;
}
// expand tuples in graph parameters
......
......@@ -31,7 +31,7 @@ namespace mindspore {
namespace opt {
// Remove the class type from graphs
void SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager);
bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager);
// Remove most uses of tuples from the graph
// tuples that are returned will be kept
......
......@@ -38,13 +38,11 @@ AnfNodePtr CastSameTypeEliminater::operator()(const OptimizerPtr &, const AnfNod
// src type check
auto src_type = src_->Type();
if (src_type == nullptr) {
if (src_type == nullptr || !src_type->isa<TensorType>()) {
return nullptr;
}
if (src_type->isa<TensorType>()) {
src_type = src_type->cast<TensorTypePtr>()->element();
}
src_type = src_type->cast<TensorTypePtr>()->element();
// tgt type check
auto tgt_type = GetValueNode<TypePtr>(tgt_);
......
......@@ -52,14 +52,16 @@ bool SimplifyDataStructuresPass(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res->func_graph());
FuncGraphPtr func_graph = res->func_graph();
opt::SimplifyDataStructures(func_graph, res->manager());
bool changed = opt::SimplifyDataStructures(func_graph, res->manager());
abstract::AbstractBasePtrList args_spec;
auto parameters = func_graph->parameters();
(void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec),
[](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); });
FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec);
res->set_func_graph(new_fg);
if (changed) {
FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec);
res->set_func_graph(new_fg);
}
res->set_args_spec(args_spec);
return true;
}
......
......@@ -177,8 +177,8 @@ std::size_t FuncGraphAbstractClosure::hash() const {
std::string FuncGraphAbstractClosure::ToString() const {
std::stringstream ss;
ss << "FuncGraphAbstractClosure: " << this << "FuncGraph: " << func_graph_.get() << ", " << func_graph_->ToString()
<< "; Context: " << context_.get() << context_->ToString();
ss << "FuncGraphAbstractClosure: "
<< "FuncGraph: " << func_graph_->ToString() << "; Context: " << context_->ToString();
return ss.str();
}
......
......@@ -166,8 +166,9 @@ class PartialAbstractClosure : public AbstractFuncAtom {
public:
// Represents a partial application.
// args_spec_list: The first few arguments of that function
PartialAbstractClosure(const AbstractFuncAtomPtr &fn, const AbstractBasePtrList &args_spec_list)
: fn_(fn), args_spec_list_(args_spec_list) {}
PartialAbstractClosure(const AbstractFuncAtomPtr &fn, const AbstractBasePtrList &args_spec_list,
const AnfNodePtr &node = nullptr)
: fn_(fn), args_spec_list_(args_spec_list), node_(AnfNodePtr(node)) {}
~PartialAbstractClosure() override = default;
MS_DECLARE_PARENT(PartialAbstractClosure, AbstractFuncAtom)
......@@ -175,7 +176,11 @@ class PartialAbstractClosure : public AbstractFuncAtom {
AbstractFunctionPtr fn() { return fn_; }
AbstractBasePtrList args() { return args_spec_list_; }
AbstractFunctionPtr Copy() const override { return std::make_shared<PartialAbstractClosure>(fn_, args_spec_list_); }
AnfNodePtr node() { return node_.lock(); }
void set_node(const AnfNodePtr &node) { node_ = AnfNodeWeakPtr(node); }
AbstractFunctionPtr Copy() const override {
return std::make_shared<PartialAbstractClosure>(fn_, args_spec_list_, node_.lock());
}
bool operator==(const AbstractFunction &other) const override;
std::size_t hash() const override;
......@@ -184,6 +189,8 @@ class PartialAbstractClosure : public AbstractFuncAtom {
private:
AbstractFuncAtomPtr fn_;
AbstractBasePtrList args_spec_list_;
// The CNode which this PartialAbstractClosure evaluated from.
AnfNodeWeakPtr node_;
};
class JTransformedAbstractClosure : public AbstractFuncAtom {
......
......@@ -951,8 +951,19 @@ class PartialEvaluator : public Evaluator {
if (args_conf_list.size() == 0) {
MS_LOG(EXCEPTION) << "Args size should be greater than 0";
}
MS_EXCEPTION_IF_NULL(out_conf);
MS_EXCEPTION_IF_NULL(out_conf->node());
auto arg0_value = args_conf_list[0]->GetEvaluatedValue();
AbstractBasePtrList args_spec_list{arg0_value};
// Func in hypermap(partial(Func, arg0), arg1, arg2) may become Poly Node.
if (arg0_value->isa<AbstractError>()) {
auto ret = std::make_shared<AbstractError>(arg0_value->GetValueTrack()->cast<StringImmPtr>(), out_conf->node());
MS_LOG(DEBUG) << "AbstractError for node: " << out_conf->node()->DebugString()
<< " as func is: " << arg0_value->ToString();
(*cache_)[args_spec_list] = ret;
return ret;
}
auto func = CheckArg<AbstractFunction>("partial", args_spec_list, 0);
// Sometimes, node[0] in out_conf becomes phi0;
if (func->isa<PrimitiveAbstractClosure>()) {
......@@ -962,19 +973,26 @@ class PartialEvaluator : public Evaluator {
return HandleDoSignature(engine, do_signature_prim->function(), out_conf);
}
}
(void)std::transform(args_conf_list.begin() + 1, args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue(); });
(void)std::transform(args_conf_list.begin() + 1, args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &config) -> AbstractBasePtr { return config->GetEvaluatedValue(); });
AbstractBasePtrList args(args_spec_list.begin() + 1, args_spec_list.end());
AbstractFuncAtomPtrList partialPtrList;
auto build_partial = [args, &partialPtrList](const AbstractFuncAtomPtr &atom_func) {
auto new_func = std::make_shared<PartialAbstractClosure>(atom_func, args);
partialPtrList.push_back(new_func);
auto cnode = out_conf->node()->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->size() != (args_conf_list.size() + 1)) {
MS_LOG(EXCEPTION) << "Out_conf node: " << cnode->DebugString()
<< ", args_conf_list: " << mindspore::ToString(args_conf_list);
}
AbstractFuncAtomPtrList partial_funcs_list;
auto build_partial = [args, cnode, &partial_funcs_list](const AbstractFuncAtomPtr &atom_func) {
auto new_func = std::make_shared<PartialAbstractClosure>(atom_func, args, cnode);
partial_funcs_list.push_back(new_func);
};
func->Visit(build_partial);
auto ret = AbstractFunction::MakeAbstractFunction(partialPtrList);
auto ret = AbstractFunction::MakeAbstractFunction(partial_funcs_list);
(*cache_)[args_spec_list] = ret;
return ret;
}
......
......@@ -23,7 +23,9 @@
#include "./common.h"
#include "operator/ops.h"
#include "operator/composite/do_signature.h"
#include "pipeline/static_analysis/abstract_function.h"
#include "utils/graph_utils.h"
#include "utils/log_adapter.h"
#include "utils/profile.h"
#include "debug/trace.h"
......@@ -232,6 +234,13 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) {
return;
}
new_node->set_abstract(GetEvaluatedValueWrap(conf));
if (new_node->isa<CNode>() && new_node->abstract()->isa<PartialAbstractClosure>()) {
auto partial_abstract = dyn_cast<PartialAbstractClosure>(new_node->abstract());
if (partial_abstract->node() == node) {
partial_abstract->set_node(new_node);
}
}
MS_LOG(DEBUG) << "Set new_node: " << new_node->ToString() << ", abstract as: " << new_node->abstract()->ToString();
if (node->isa<CNode>()) {
......@@ -383,6 +392,56 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AbstractBasePtr
return BuildValueNode(v, abs);
}
AnfNodePtr FuncGraphSpecializer::BuildSpecializedParameterNode(const CNodePtr &new_node) {
auto new_inputs = new_node->inputs();
AnfNodePtr func = new_inputs[0];
AbstractBasePtr fnval = new_inputs[0]->abstract();
AbstractBasePtrList args;
auto backed_fnval = fnval;
if (fnval->isa<PartialAbstractClosure>()) {
auto partial_closure = dyn_cast<PartialAbstractClosure>(fnval);
backed_fnval = partial_closure->fn();
args = partial_closure->args();
}
std::transform(new_inputs.cbegin() + 1, new_inputs.cend(), std::back_inserter(args),
[](const AnfNodePtr &inp) { return inp->abstract(); });
ScopeGuard scope_guard(new_node->scope());
auto specialized_node = BuildSpecializedNode(func, backed_fnval, args);
auto wrapped_node = specialized_node;
if (fnval->isa<PartialAbstractClosure>()) {
auto partial_closure = dyn_cast<PartialAbstractClosure>(fnval);
AnfNodePtrList partial_node_list = {BuildValueNode(prim::kPrimPartial, FromValueInside(prim::kPrimPartial)),
specialized_node};
auto anf_node = partial_closure->node();
if (!anf_node->isa<CNode>()) {
MS_LOG(EXCEPTION) << "Must be cnode, but " << anf_node->DebugString();
}
auto cnode = anf_node->cast<CNodePtr>();
if (cnode->size() != partial_closure->args().size() + 2) {
MS_LOG(EXCEPTION) << "Size of cnode: " << cnode->DebugString()
<< " is not equal to 2 added to size of args: " << mindspore::ToString(partial_closure->args());
}
for (size_t i = 0; i < partial_closure->args().size(); i++) {
auto old_node = cnode->input(i + 2);
auto possibile_value_node = BuildPossibleValueNode(old_node, partial_closure->args()[i]);
if (possibile_value_node != nullptr) {
partial_node_list.push_back(possibile_value_node);
} else {
if (!(old_node->isa<CNode>() || old_node->isa<Parameter>())) {
MS_LOG(EXCEPTION) << "Old node should be CNode or Parameter, but " << old_node->ToString();
}
partial_node_list.push_back(old_node);
}
}
wrapped_node = new_node->func_graph()->NewCNode(partial_node_list);
wrapped_node->set_abstract(partial_closure);
}
return wrapped_node;
}
const EvaluatorCacheMapPtr &FuncGraphSpecializer::GetEvalCache(const EvaluatorPtr &eval) {
auto cache_iter = evalcaches_.find(eval);
if (cache_iter == evalcaches_.end()) {
......@@ -465,6 +524,11 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
<< new_inputs[i]->DebugString() << ", abstract: " << new_inputs[i]->abstract()->ToString();
}
if (func->isa<Parameter>() && func->func_graph()->has_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER)) {
auto wrapped_node = BuildSpecializedParameterNode(new_node);
new_inputs[0] = wrapped_node;
}
if (CanSpecializeNode(func)) {
new_inputs[0] = BuildSpecializedNode(func, fnval, argvals);
}
......@@ -474,16 +538,6 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
if (CanSpecializeNode(args[i])) {
new_inputs[next] = BuildSpecializedNode(args[i], argvals[i], std::vector<AbstractBasePtr>{});
}
// support for partial(Multitype) which Multitype should not be inferred to POLY.
// after one or more times clone, Multitype metafuncgraph evaluator will specialized to one type only,
// so even with partial parameter, it will specialize to that graph.
// Maybe a better idea should inline graph with partial node first, then it will have full
// parameter list to infer and specialize.
MS_EXCEPTION_IF_NULL(new_inputs[next]);
if (new_inputs[next]->isa<ValueNode>() && (GetValueNode(new_inputs[next]) == kPolyNode) &&
IsPrimitive(func, prim::kPrimPartial)) {
new_inputs[next] = args[i];
}
i = next;
}
......
......@@ -106,6 +106,9 @@ class FuncGraphSpecializer : public std::enable_shared_from_this<FuncGraphSpecia
// (disconnected).
AnfNodePtr ReplicateDisconnectedNode(const AnfNodePtr &node);
// Build a value node from parameter if the function graph has special flag to hint it can be done.
AnfNodePtr BuildSpecializedParameterNode(const CNodePtr &new_node);
// Build a value node if ival is constant and not any-value
AnfNodePtr BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival);
// Build a replacable node for iconf->node; it may be a replicated forwared CNode in static analysis or just a
......
......@@ -87,11 +87,6 @@ class CumSumNet(nn.Cell):
raise_set = [
# one input is scalar, and another is Tensor(float32)
('TensorAdd0', {
'block': (P.TensorAdd(), {'exception': TypeError, 'error_keywords': ['TensorAdd']}),
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# input two tensors, but element types are not same
('TensorAdd1', {
'block': (P.TensorAdd(), {'exception': TypeError, 'error_keywords': ['TensorAdd']}),
......@@ -271,11 +266,6 @@ raise_set = [
'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.bool_))],
'skip': ['backward']}),
# one input is scalar, and another is Tensor(float32)
('Sub0', {
'block': (P.Sub(), {'exception': TypeError, 'error_keywords': ['Sub']}),
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# input two tensors, but element types are not same
('Sub1', {
'block': (P.Sub(), {'exception': TypeError, 'error_keywords': ['Sub']}),
......@@ -287,11 +277,6 @@ raise_set = [
'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# one input is scalar, and another is Tensor(float32)
('Mul0', {
'block': (P.Mul(), {'exception': TypeError, 'error_keywords': ['Mul']}),
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# input two tensors, but element types are not same
('Mul1', {
'block': (P.Mul(), {'exception': TypeError, 'error_keywords': ['Mul']}),
......@@ -352,11 +337,6 @@ raise_set = [
'desc_inputs': [5.0],
'skip': ['backward']}),
# one input is scalar, and another is Tensor(float32)
('Minimum0', {
'block': (P.Minimum(), {'exception': TypeError, 'error_keywords': ['Minimum']}),
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# input two tensors, but element types are not same
('Minimum1', {
'block': (P.Minimum(), {'exception': TypeError, 'error_keywords': ['Minimum']}),
......@@ -368,11 +348,6 @@ raise_set = [
'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# one input is scalar, and another is Tensor(float32)
('Maximum0', {
'block': (P.Maximum(), {'exception': TypeError, 'error_keywords': ['Maximum']}),
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# input two tensors, but element types are not same
('Maximum1', {
'block': (P.Maximum(), {'exception': TypeError, 'error_keywords': ['Maximum']}),
......@@ -384,11 +359,6 @@ raise_set = [
'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# one input is scalar, and another is Tensor(float32)
('RealDiv0', {
'block': (P.RealDiv(), {'exception': TypeError, 'error_keywords': ['RealDiv']}),
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# input two tensors, but element types are not same
('RealDiv1', {
'block': (P.RealDiv(), {'exception': TypeError, 'error_keywords': ['RealDiv']}),
......@@ -400,11 +370,6 @@ raise_set = [
'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# one input is scalar, and another is Tensor(float32)
('Div0', {
'block': (P.Div(), {'exception': TypeError, 'error_keywords': ['Div']}),
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# input two tensors, but element types are not same
('Div1', {
'block': (P.Div(), {'exception': TypeError, 'error_keywords': ['Div']}),
......@@ -416,11 +381,6 @@ raise_set = [
'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# one input is scalar, and another is Tensor(float32)
('FloorDiv0', {
'block': (P.FloorDiv(), {'exception': TypeError, 'error_keywords': ['FloorDiv']}),
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# input two tensors, but element types are not same
('FloorDiv1', {
'block': (P.FloorDiv(), {'exception': TypeError, 'error_keywords': ['FloorDiv']}),
......@@ -439,11 +399,6 @@ raise_set = [
'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.int32))],
'skip': ['backward']}),
# one input is scalar, and another is Tensor(float32)
('FloorMod0', {
'block': (P.FloorMod(), {'exception': TypeError, 'error_keywords': ['FloorMod']}),
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# input two tensors, but element types are not same
('FloorMod1', {
'block': (P.FloorMod(), {'exception': TypeError, 'error_keywords': ['FloorMod']}),
......@@ -462,11 +417,6 @@ raise_set = [
'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.bool_))],
'skip': ['backward']}),
# input is not tensor
('Equal0', {
'block': (P.Equal(), {'exception': TypeError, 'error_keywords': ['Equal']}),
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# type of x and y not match
('Equal1', {
'block': (P.Equal(), {'exception': TypeError, 'error_keywords': ['Equal']}),
......@@ -490,11 +440,6 @@ raise_set = [
'skip': ['backward']}),
# shape of x and y not match
# input is not tensor
('NotEqual0', {
'block': (P.NotEqual(), {'exception': TypeError, 'error_keywords': ['NotEqual']}),
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# type of x and y not match
('NotEqual1', {
'block': (P.NotEqual(), {'exception': TypeError, 'error_keywords': ['NotEqual']}),
......@@ -506,11 +451,6 @@ raise_set = [
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32))],
'skip': ['backward']}),
# input is not tensor
('Greater0', {
'block': (P.Greater(), {'exception': TypeError, 'error_keywords': ['Greater']}),
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# type of x and y not match
('Greater1', {
'block': (P.Greater(), {'exception': TypeError, 'error_keywords': ['Greater']}),
......@@ -522,11 +462,6 @@ raise_set = [
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32))],
'skip': ['backward']}),
# input is not tensor
('GreaterEqual0', {
'block': (P.GreaterEqual(), {'exception': TypeError, 'error_keywords': ['GreaterEqual']}),
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# type of x and y not match
('GreaterEqual1', {
'block': (P.GreaterEqual(), {'exception': TypeError, 'error_keywords': ['GreaterEqual']}),
......@@ -538,11 +473,6 @@ raise_set = [
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32))],
'skip': ['backward']}),
# input is not tensor
('Less0', {
'block': (P.Less(), {'exception': TypeError, 'error_keywords': ['Less']}),
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# type of x and y not match
('Less1', {
'block': (P.Less(), {'exception': TypeError, 'error_keywords': ['Less']}),
......@@ -554,11 +484,6 @@ raise_set = [
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32))],
'skip': ['backward']}),
# input is not tensor
('LessEqual0', {
'block': (P.LessEqual(), {'exception': TypeError, 'error_keywords': ['LessEqual']}),
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# type of x and y not match
('LessEqual1', {
'block': (P.LessEqual(), {'exception': TypeError, 'error_keywords': ['LessEqual']}),
......@@ -728,11 +653,6 @@ raise_set = [
'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.bool_))],
'skip': ['backward']}),
# one input is scalar, and another is Tensor(float32)
('Atan20', {
'block': (P.Atan2(), {'exception': TypeError, 'error_keywords': ['Atan2']}),
'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
'skip': ['backward']}),
# input two tensors, but element types are not same
('Atan21', {
'block': (P.Atan2(), {'exception': TypeError, 'error_keywords': ['Atan2']}),
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_hypermap_partial """
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context
import mindspore.common.dtype as mstype
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.api import ms_function
context.set_context(mode=context.GRAPH_MODE)
def test_hypermap_specialize_param():
class Net(nn.Cell):
""" Net definition """
def __init__(self):
super(Net, self).__init__()
self.mul = P.Mul()
def construct(self, x, y):
ret = self.mul(x, y)
return ret
factor1 = Tensor(5, dtype=mstype.int32)
x = Tensor(np.ones([1]).astype(np.int32))
y = Tensor(np.ones([2]).astype(np.int32))
net = Net()
hypermap = C.HyperMap()
@ms_function
def hypermap_specialize_param():
ret1 = hypermap(F.partial(net, factor1), (x, y))
# List will be converted to Tuple in SimlifyDataStructurePass.
ret2 = hypermap(F.partial(net, factor1), [x, y])
return ret1, ret2
expected_ret = (Tensor(np.full(1, 5).astype(np.int32)), Tensor(np.full(2, 5).astype(np.int32)))
ret = hypermap_specialize_param()
assert(ret == (expected_ret, expected_ret))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册