提交 40251d95 编写于 作者: L lirongzhen1

configure auto parallel tensors shape

上级 6fb55381
...@@ -118,11 +118,9 @@ const size_t UndeterminedShapeType::fields_num = 6; ...@@ -118,11 +118,9 @@ const size_t UndeterminedShapeType::fields_num = 6;
std::unordered_map<std::string, UndeterminedShapeType> g_undetermined_configs; std::unordered_map<std::string, UndeterminedShapeType> g_undetermined_configs;
void InitUndeterminedFromEnv(const std::string &sparse_shape_types) { void InitUndeterminedFromEnv(const std::string &sparse_shape_types) {
if (!g_undetermined_configs.empty()) {
return;
}
std::string tmp; std::string tmp;
std::stringstream input(sparse_shape_types); std::stringstream input(sparse_shape_types);
g_undetermined_configs.clear();
while (std::getline(input, tmp, ';')) { while (std::getline(input, tmp, ';')) {
auto config = UndeterminedShapeType(tmp); auto config = UndeterminedShapeType(tmp);
g_undetermined_configs.insert(std::make_pair(config.param_name(), config)); g_undetermined_configs.insert(std::make_pair(config.param_name(), config));
...@@ -145,17 +143,19 @@ AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePt ...@@ -145,17 +143,19 @@ AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePt
if (!key->sparse_grad().empty()) { if (!key->sparse_grad().empty()) {
// Will be fixed once undetermined type ready // Will be fixed once undetermined type ready
auto sparse_shape_types = common::GetEnv("UNDETERMINED_SPARSE_SHAPE_TYPES"); if (g_undetermined_configs.empty()) {
if (sparse_shape_types.empty()) { auto sparse_shape_types = common::GetEnv("UNDETERMINED_SPARSE_SHAPE_TYPES");
sparse_shape_types = "sparse_key_w1:2:Int32:2 1 2:Float32:3 1 2;sparse_key_w2:2:Int32:2 1 2:Float32:3 1 2"; MS_LOG(INFO) << "Undetermind sparse shape:" << sparse_shape_types;
if (sparse_shape_types.empty()) {
sparse_shape_types = "sparse_key_w1:2:Int32:2 1 2:Float32:3 1 2;sparse_key_w2:2:Int32:2 1 2:Float32:3 1 2";
}
InitUndeterminedFromEnv(sparse_shape_types);
} }
InitUndeterminedFromEnv(sparse_shape_types);
auto shape_types = g_undetermined_configs.find(key->sparse_grad()); auto shape_types = g_undetermined_configs.find(key->sparse_grad());
if (shape_types == g_undetermined_configs.end()) { if (shape_types == g_undetermined_configs.end()) {
MS_LOG(EXCEPTION) << "Param " << key->ToString() MS_LOG(EXCEPTION) << "Param " << key->ToString()
<< " has sparse_grad, but shape/type is not configured in env UNDETERMINED_SPARSE_SHAPE_TYPES: " << " has sparse_grad, but shape/type is not configured in env UNDETERMINED_SPARSE_SHAPE_TYPES";
<< sparse_shape_types;
} }
MS_LOG(DEBUG) << "EnvGetItem is sparse_grad " << key->ToString(); MS_LOG(DEBUG) << "EnvGetItem is sparse_grad " << key->ToString();
AbstractBasePtrList sparse_list; AbstractBasePtrList sparse_list;
......
...@@ -43,6 +43,7 @@ ...@@ -43,6 +43,7 @@
#include "parallel/strategy_checkpoint/parallel_strategy_checkpoint.h" #include "parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
#include "utils/comm_manager.h" #include "utils/comm_manager.h"
#include "utils/symbolic.h" #include "utils/symbolic.h"
#include "pipeline/static_analysis/prim.h"
using mindspore::tensor::Tensor; using mindspore::tensor::Tensor;
...@@ -1371,6 +1372,11 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) { ...@@ -1371,6 +1372,11 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
<< cloned_index << ", but not found the be cloned parameter"; << cloned_index << ", but not found the be cloned parameter";
} }
} }
std::string env = common::GetEnv("SLICE_ENV");
if (!env.empty()) {
MS_LOG(INFO) << "Slice tensors shape will be configured from env:" << env;
abstract::InitUndeterminedFromEnv(env);
}
} }
void SetVirtualDatasetStrategy(const CNodePtr &node) { void SetVirtualDatasetStrategy(const CNodePtr &node) {
......
...@@ -349,6 +349,7 @@ AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const Primitiv ...@@ -349,6 +349,7 @@ AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const Primitiv
AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list); const AbstractBasePtrList &args_spec_list);
void InitUndeterminedFromEnv(const std::string &sparse_shape_types);
} // namespace abstract } // namespace abstract
} // namespace mindspore } // namespace mindspore
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册