Commit aff6777e authored by: buxue

Fix ReviewBot issues, fix the example of TruncatedNormal, and add type mapping

Parent c6d21ccd
......@@ -134,7 +134,7 @@ class DebugInfo : public Base {
explicit DebugInfo(const LocationPtr &loc);
-virtual ~DebugInfo() = default;
+~DebugInfo() override = default;
MS_DECLARE_PARENT(DebugInfo, Base);
int64_t debug_id();
int64_t unique_id() const { return unique_id_; }
......
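The hunk above replaces `virtual ~DebugInfo() = default;` with `~DebugInfo() override = default;`. Marking the destructor `override` asks the compiler to verify that the base class really declares a virtual destructor, which a bare `virtual` cannot do. A minimal standalone sketch of the idiom (`Base` and `Derived` are illustrative names, not MindSpore types):

```cpp
#include <memory>

struct Base {
  virtual ~Base() = default;  // virtual base destructor allows safe delete via Base*
};

struct Derived : Base {
  // `override` makes the compiler check that this really overrides a virtual
  // member of Base; with plain `virtual`, a mismatch would pass silently.
  ~Derived() override = default;
};

int main() {
  std::unique_ptr<Base> p = std::make_unique<Derived>();  // Derived's destructor runs here
  return 0;
}
```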
......@@ -231,10 +231,10 @@ std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr &node) {
auto engine = node_cfg_->engine();
auto cfg = engine->MakeConfig(node, ctx);
auto abs = engine->cache().GetValue(cfg);
if (abs == nullptr) {
return "Undefined";
}
auto dtype = abs->BuildType();
auto shape = abs->BuildShape();
std::ostringstream oss;
......
......@@ -321,7 +321,7 @@ class TraceTransform : public TraceInfo {
std::string full_name() override { return full_name_ + transform_name_; }
MS_DECLARE_PARENT(TraceTransform, TraceInfo);
-virtual std::string symbol() {
+std::string symbol() override {
if (transform_name_.empty()) {
return "";
}
......
......@@ -87,6 +87,12 @@ const char *MetaIdLabel(const TypeId &v) {
return "kMetaTypeExternal";
case kMetaTypeNone:
return "kMetaTypeNone";
+case kMetaTypeNull:
+return "kMetaTypeNull";
+case kMetaTypeEllipsis:
+return "kMetaTypeEllipsis";
+case kMetaTypeEnd:
+return "kMetaTypeEnd";
default:
return "[Unknown Type Id]";
}
......
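These three added cases complete the TypeId-to-label mapping referenced in the commit message: a plain switch over the enum with a defensive `default` arm. A self-contained sketch of the pattern, using a stand-in enum rather than MindSpore's `TypeId`:

```cpp
#include <cstdio>

enum MetaIdDemo { kDemoNone, kDemoNull, kDemoEllipsis, kDemoEnd };  // stand-in enum

// Map each enumerator to a printable label; the default arm keeps the function
// total if a new enumerator is added before the switch is updated.
const char *MetaIdDemoLabel(MetaIdDemo v) {
  switch (v) {
    case kDemoNone:     return "kDemoNone";
    case kDemoNull:     return "kDemoNull";
    case kDemoEllipsis: return "kDemoEllipsis";
    case kDemoEnd:      return "kDemoEnd";
    default:            return "[Unknown Type Id]";
  }
}

int main() {
  std::printf("%s\n", MetaIdDemoLabel(kDemoEllipsis));  // prints "kDemoEllipsis"
  return 0;
}
```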
......@@ -133,7 +133,6 @@ ResolveIRPassLib::ResolveIRPassLib() {
InferenceOptPrepareLib::InferenceOptPrepareLib() {
grad_var_prepare_ = MakeSubstitution(GradVarPrepare(), "grad_var_prepare", IsCNode);
}
} // namespace irpass
} // namespace opt
} // namespace mindspore
......@@ -159,7 +159,6 @@ inline bool IsCNodeDup(const AnfNodePtr &node) {
}
return false;
}
} // namespace irpass
} // namespace opt
} // namespace mindspore
......
......@@ -31,7 +31,6 @@
namespace mindspore {
namespace opt {
namespace irpass {
static AnfNodePtr GenerateUnpackGraphNode(std::vector<AnfNodePtr> inputs_y, FuncGraphPtr func_graph,
AnfNodePtr func_node, bool is_unpack, bool sens_param) {
MS_EXCEPTION_IF_NULL(func_graph);
......
......@@ -33,7 +33,6 @@
namespace mindspore {
namespace opt {
namespace irpass {
// {{GradOperation, g, w}, Ys}
// {UnPackCall, {GradOperation, g, w}, Ys}
class GradVarPrepare : public AnfVisitor {
......
......@@ -28,13 +28,11 @@
namespace mindspore {
namespace pipeline {
struct ExecutorInfo {
FuncGraphPtr func_graph;
ResourcePtr resource;
std::size_t arg_list_size;
};
using ExecutorInfoPtr = std::shared_ptr<ExecutorInfo>;
inline std::string GetPhasePrefix(const std::string &phase) {
......
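`ExecutorInfo` follows the file's convention of pairing each struct with a `shared_ptr` alias (`ExecutorInfoPtr`), so every consumer shares ownership of one instance. A minimal sketch of the idiom, with stand-in names:

```cpp
#include <cstddef>
#include <memory>

struct ExecutorInfoDemo {          // stand-in for the real ExecutorInfo
  std::size_t arg_list_size = 0;
};
using ExecutorInfoDemoPtr = std::shared_ptr<ExecutorInfoDemo>;

int main() {
  auto info = std::make_shared<ExecutorInfoDemo>();
  ExecutorInfoDemoPtr alias = info;  // both handles co-own the same object
  alias->arg_list_size = 3;
  return 0;
}
```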
......@@ -101,7 +101,7 @@ py::tuple GenerateKey(const std::string &name, const std::unordered_map<std::str
MS_LOG(INFO) << "Start new args and compile key:" << key;
g_args_cache[args_spec] = key++;
}
-py::tuple argSpec = py::tuple(2);
+auto argSpec = py::tuple(2);
argSpec[0] = name;
argSpec[1] = g_args_cache[args_spec];
return argSpec;
......
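The `auto` spelling avoids restating `py::tuple` when the type is already visible on the right-hand side. A short pybind11 sketch of building such a 2-tuple the way `GenerateKey` does (the function name here is illustrative):

```cpp
#include <pybind11/pybind11.h>
#include <string>

namespace py = pybind11;

// Illustrative helper: pack a (name, key) pair into a Python tuple.
py::tuple MakeArgSpec(const std::string &name, int key) {
  auto arg_spec = py::tuple(2);  // `auto` instead of repeating py::tuple
  arg_spec[0] = name;            // pybind11 casts std::string to a Python str
  arg_spec[1] = key;             // and int to a Python int
  return arg_spec;
}
```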
......@@ -52,11 +52,11 @@ void DoExecNonInputGraph(const std::string &phase) {
transform::RunOptions run_options;
run_options.name = phase;
auto graph_runner = DfGraphManager::GetInstance().GetGraphRunner();
if (graph_runner == nullptr) {
MS_LOG(ERROR) << "Can not found GraphRunner";
return;
}
{
// Release GIL before calling into (potentially long-running) C++ code
py::gil_scoped_release release;
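The block above drops the GIL before long-running C++ work so other Python threads are not stalled. A minimal pybind11 sketch of the same pattern (module and function names are made up for illustration):

```cpp
#include <pybind11/pybind11.h>

namespace py = pybind11;

// While `release` is alive the GIL is not held, so other Python threads can
// run during the pure-C++ loop; the GIL is re-acquired when scope exits.
double SlowSum(long n) {
  py::gil_scoped_release release;
  double s = 0.0;
  for (long i = 0; i < n; ++i) s += static_cast<double>(i);
  return s;
}

PYBIND11_MODULE(gil_demo, m) {
  m.def("slow_sum", &SlowSum, "Sum 0..n-1 without holding the GIL");
}
```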
......@@ -181,7 +181,6 @@ bool AddDFGraph(const std::map<std::string, ExecutorInfoPtr> &info, const py::di
size_t pos = phase.find('.');
std::string net_id = ((pos == std::string::npos || pos == phase.size() - 1) ? phase : phase.substr(pos + 1));
std::string phase_prefix = phase.substr(0, pos);
if (phase_prefix == "export") {
MS_LOG(INFO) << "Set DfGraphConvertor training : false";
convertor.set_training(false);
......@@ -348,7 +347,7 @@ py::object ExtractGeneralCnodeRet(const AbstractBasePtr &cnode_data, const py::t
auto data_tp = cnode_data->cast<AbstractTuplePtr>();
auto elements = data_tp->elements();
size_t size = data_tp->size();
-py::tuple tp = py::tuple(size);
+auto tp = py::tuple(size);
for (size_t i = 0; i < size; i++) {
tp[i] = ExtractGeneralCnodeRet(elements[i], data, count);
}
......@@ -379,7 +378,7 @@ py::object StructureOutput(const AnfNodePtr &output_node, const py::tuple &data,
if (output_c->IsApply(prim::kPrimMakeTuple)) {
auto input_list = output_c->inputs();
size_t size = input_list.size();
-py::tuple tp = py::tuple(size - 1);
+auto tp = py::tuple(size - 1);
for (size_t i = 1; i < size; i++) {
tp[i - 1] = StructureOutput(input_list[i], data, count);
}
......@@ -401,11 +400,8 @@ std::shared_ptr<py::object> DoExecGraph(const FuncGraphPtr &graph, const std::ve
std::vector<GeTensorPtr> ge_outputs;
transform::RunOptions run_options;
run_options.name = phase;
auto graph_runner = DfGraphManager::GetInstance().GetGraphRunner();
if (graph_runner == nullptr) {
MS_LOG(EXCEPTION) << "Can not found GraphRunner.";
}
......@@ -478,7 +474,6 @@ void ProcessGeArg(const std::map<std::string, ExecutorInfoPtr> &info, const py::
py::object ExecDFGraph(const std::map<std::string, ExecutorInfoPtr> &info, const py::tuple &args,
const std::string &phase) {
std::string phase_prefix = GetPhasePrefix(phase);
if (phase_prefix == "save") {
DoExecNonInputGraph(phase);
ConfigManager::GetInstance().ResetConfig();
......@@ -488,7 +483,6 @@ py::object ExecDFGraph(const std::map<std::string, ExecutorInfoPtr> &info, const
if (info.count(phase) == 0) {
MS_LOG(EXCEPTION) << "There is no phase:" << phase;
}
FuncGraphPtr anf_graph = info.at(phase)->func_graph;
#ifdef ENABLE_INFER
......
......@@ -31,7 +31,6 @@
namespace mindspore {
namespace pipeline {
namespace py = pybind11;
void SetGeOption(const std::map<std::string, std::string> &options);
......@@ -50,7 +49,6 @@ bool InitExecDatasetGe(const std::string &queue_name, int64_t size, int64_t batc
const std::vector<int64_t> &input_indexes, const std::string &phase);
void ExportDFGraph(const std::string &file_name, const std::string &phase);
} // namespace pipeline
} // namespace mindspore
......
......@@ -41,7 +41,7 @@ class AbstractFuncAtom : public AbstractFunction {
AbstractFunctionPtr Join(const AbstractFunctionPtr &other) final;
void Visit(std::function<void(const AbstractFuncAtomPtr &)>) const final;
-bool operator==(const AbstractFunction &other) const;
+bool operator==(const AbstractFunction &other) const override;
std::size_t hash() const override { return tid(); }
};
......@@ -270,7 +270,7 @@ class TypedPrimitiveAbstractClosure : public AbstractFuncAtom {
class DummyAbstractClosure : public AbstractFuncAtom {
public:
DummyAbstractClosure() = default;
-~DummyAbstractClosure() = default;
+~DummyAbstractClosure() override = default;
MS_DECLARE_PARENT(DummyAbstractClosure, AbstractFuncAtom)
EvaluatorPtr GetEvaluator(AnalysisEnginePtr) override { MS_LOG(EXCEPTION) << "A dummy function cannot eval."; }
......
......@@ -295,7 +295,6 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
dic["shape"] = shape;
dic["dtype"] = arg_slice->BuildType();
dic["value"] = BuildValue(arg_slice->BuildValue());
} else if (abs_base->isa<AbstractTuple>()) {
auto arg_tuple = dyn_cast<AbstractTuple>(abs_base);
size_t len = arg_tuple->size();
......
......@@ -639,9 +639,9 @@ class TruncatedNormal(PrimitiveWithInfer):
Tensor, type of output tensor is same as attribute `dtype`.
Examples:
->>> input_shape = Tensor(np.array([1, 2, 3]))
+>>> shape = (1, 2, 3)
>>> truncated_normal = P.TruncatedNormal()
->>> output = truncated_normal(input_shape)
+>>> output = truncated_normal(shape)
"""
@prim_attr_register
......@@ -652,6 +652,8 @@ class TruncatedNormal(PrimitiveWithInfer):
def __infer__(self, shape):
shape_value = shape['value']
validator.check_const_input("shape", shape_value)
validator.check_type("shape", shape_value, [tuple])
for i, value in enumerate(shape_value):
validator.check_integer(f'{i}th value of shape', value, 0, Rel.GT)
out = {'shape': shape_value,
......@@ -1642,15 +1644,16 @@ class StridedSlice(PrimitiveWithInfer):
validator.check_type('shrink_axis_mask', shrink_axis_mask, [int])
def __infer__(self, x, begin, end, strides):
-x_shape = x['shape']
-x_shp_len = len(x_shape)
begin_v, end_v, strides_v = begin['value'], end['value'], strides['value']
validator.check_const_input("begin", begin_v)
validator.check_const_input("end", end_v)
validator.check_const_input("strides", strides_v)
validator.check_type("begin", begin['value'], [tuple])
validator.check_type("end", end['value'], [tuple])
validator.check_type("strides", strides['value'], [tuple])
validator.check_type("begin", begin_v, [tuple])
validator.check_type("end", end_v, [tuple])
validator.check_type("strides", strides_v, [tuple])
x_shape = x['shape']
x_shp_len = len(x_shape)
if len(begin_v) != x_shp_len or len(end_v) != x_shp_len or len(strides_v) != x_shp_len:
raise ValueError(f"The length of begin index{begin_v}, end index{end_v} and strides{strides_v} "
f"must be equal to the dims({x_shp_len}) of input.")
......
......@@ -372,7 +372,7 @@ test_case_math_ops = [
'desc_bprop': [[3]]}),
('TruncatedNormal', {
'block': P.TruncatedNormal(),
-'desc_const': [[1, 2, 3]],
+'desc_const': [(1, 2, 3)],
'desc_inputs': [],
'skip': ['backward'],
'add_fake_input': True}),
......