提交 f17c822c 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2064 add memory context variable

Merge pull request !2064 from caifubi/add-memory-context-var
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include <string>
#include "device/ascend/ascend_memory_manager.h" #include "device/ascend/ascend_memory_manager.h"
#include "device/ascend/ascend_memory_pool.h" #include "device/ascend/ascend_memory_pool.h"
#include "utils/context/ms_context.h" #include "utils/context/ms_context.h"
...@@ -21,25 +21,52 @@ ...@@ -21,25 +21,52 @@
namespace mindspore { namespace mindspore {
namespace device { namespace device {
namespace ascend { namespace ascend {
const uint64_t kAscendDeviceMemGB = 26; constexpr uint64_t kAscendDeviceMemGB = 26;
const uint64_t kAscendMemPoolGB = 4; constexpr uint64_t kAscendMemPoolGB = 4;
const uint64_t kAscendDeviceMemSize = (kAscendDeviceMemGB << 30); constexpr uint64_t kMemSizeGB = 30;
const uint64_t kAscendMemPoolSize = (kAscendMemPoolGB << 30); constexpr uint64_t kMaxMemSizeGB = 30;
constexpr uint64_t kAscendDeviceMemSize = (kAscendDeviceMemGB << kMemSizeGB);
constexpr uint64_t kAscendMemPoolSize = (kAscendMemPoolGB << kMemSizeGB);
void AscendMemoryManager::MallocDeviceMemory() { void AscendMemoryManager::MallocDeviceMemory() {
device_mem_size_ = kAscendDeviceMemSize; auto context_mem = GetDeviceMemSizeFromContext();
device_mem_size_ = context_mem == 0 ? kAscendDeviceMemSize : context_mem;
static_mem_offset_ = device_mem_size_; static_mem_offset_ = device_mem_size_;
auto ret = rtMalloc(reinterpret_cast<void **>(&device_mem_base_), static_mem_offset_, RT_MEMORY_HBM); auto ret = rtMalloc(reinterpret_cast<void **>(&device_mem_base_), static_mem_offset_, RT_MEMORY_HBM);
if (ret != RT_ERROR_NONE) { if (ret != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "rtMalloc mem size[" << static_mem_offset_ << "] fail, ret[" << ret << "]"; MS_EXCEPTION(DeviceProcessError) << "rtMalloc mem size[" << static_mem_offset_ << "] fail, ret[" << ret << "]";
} }
device_mem_pool_size_ = kAscendMemPoolSize;
ret = rtMalloc(reinterpret_cast<void **>(&device_mem_pool_base_), device_mem_pool_size_, RT_MEMORY_HBM); if (context_mem == 0) {
if (ret != RT_ERROR_NONE) { device_mem_pool_size_ = kAscendMemPoolSize;
MS_EXCEPTION(DeviceProcessError) << "rtMalloc mem size[" << device_mem_pool_size_ << "] fail, ret[" << ret << "]"; ret = rtMalloc(reinterpret_cast<void **>(&device_mem_pool_base_), device_mem_pool_size_, RT_MEMORY_HBM);
if (ret != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "rtMalloc mem size[" << device_mem_pool_size_ << "] fail, ret[" << ret << "]";
}
AscendMemoryPool::GetInstance().set_device_mem_pool_base(device_mem_pool_base_);
AscendMemoryPool::GetInstance().set_device_mem_pool_size(device_mem_pool_size_);
}
}
uint64_t AscendMemoryManager::GetDeviceMemSizeFromContext() {
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
auto variable_memory_max_size = context->variable_memory_max_size();
if (variable_memory_max_size == "0") {
return 0;
}
MS_LOG(INFO) << "context variable_memory_max_size:" << variable_memory_max_size;
auto pos = variable_memory_max_size.find('*');
if (pos == std::string::npos) {
MS_LOG(EXCEPTION) << "Invalid variable_memory_max_size";
}
auto gb_str = variable_memory_max_size.substr(0, pos);
auto gb_var = std::stoull(gb_str);
MS_LOG(INFO) << "variable_memory_max_size(GB):" << gb_var;
if (gb_var > kMaxMemSizeGB || gb_var == 0) {
MS_LOG(EXCEPTION) << "Invalid allocate memory size:" << gb_var << " which should be in (0-30]GB";
} }
AscendMemoryPool::GetInstance().set_device_mem_pool_base(device_mem_pool_base_); return gb_var << kMemSizeGB;
AscendMemoryPool::GetInstance().set_device_mem_pool_size(device_mem_pool_size_);
} }
void AscendMemoryManager::FreeDeviceMemory() { void AscendMemoryManager::FreeDeviceMemory() {
......
...@@ -32,6 +32,8 @@ class AscendMemoryManager : public MemoryManager { ...@@ -32,6 +32,8 @@ class AscendMemoryManager : public MemoryManager {
private: private:
uint8_t *device_mem_pool_base_{nullptr}; uint8_t *device_mem_pool_base_{nullptr};
uint64_t device_mem_pool_size_{0}; uint64_t device_mem_pool_size_{0};
uint64_t GetDeviceMemSizeFromContext();
}; };
} // namespace ascend } // namespace ascend
} // namespace device } // namespace device
......
...@@ -140,6 +140,10 @@ class MsContext { ...@@ -140,6 +140,10 @@ class MsContext {
variable_memory_max_size_ = variable_memory_max_size; variable_memory_max_size_ = variable_memory_max_size;
} }
const std::string &variable_memory_max_size() const { return variable_memory_max_size_; }
const std::string &graph_memory_max_size() const { return graph_memory_max_size_; }
void set_enable_profiling(bool flag) { profiling_mode_ = flag; } void set_enable_profiling(bool flag) { profiling_mode_ = flag; }
bool enable_profiling() const { return profiling_mode_; } bool enable_profiling() const { return profiling_mode_; }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册