Commit 065b8cbd authored by mindspore-ci-bot, committed by Gitee

!2700 use same batchnorm operation in PyNative mode and Graph mode

Merge pull request !2700 from chujinjin/change_pynative_batchnorm_same_as_graph_mode
@@ -174,7 +174,13 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const std::vector<int
   if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NDHWC) {
     device_shape = trans::TransShapeToDevice(host_shape, format_);
   } else {
-    host_shape = trans::PaddingShapeTo4d(host_shape);
+    if (host_shape_.empty()) {
+      host_shape = trans::PaddingShapeTo4d(host_shape);
+    } else {
+      host_shape.clear();
+      (void)std::transform(host_shape_.begin(), host_shape_.end(), std::back_inserter(host_shape), IntToSize);
+    }
     device_shape = trans::TransShapeToDevice(host_shape, format_);
   }
   if (type_id_ != type) {
...
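In plain terms, the new branch prefers a host shape previously recorded on the device address over blind 4D padding, converting it from `int` to `size_t` the way the `IntToSize` call does. A minimal standalone sketch of that selection logic, with hypothetical stand-ins for `trans::PaddingShapeTo4d` and `IntToSize`:

```cpp
#include <algorithm>
#include <cstddef>
#include <iterator>
#include <vector>

// Hypothetical stand-in for trans::PaddingShapeTo4d: left-pad with 1s to rank 4.
std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape) {
  const size_t pad = shape.size() < 4 ? 4 - shape.size() : 0;
  std::vector<size_t> padded(pad, 1);
  padded.insert(padded.end(), shape.begin(), shape.end());
  return padded;
}

// Hypothetical stand-in for IntToSize.
size_t IntToSize(int v) { return static_cast<size_t>(v); }

// Mirrors the new branch: reuse the recorded host shape when present,
// otherwise fall back to padding the inferred shape to 4D.
std::vector<size_t> SelectHostShape(const std::vector<size_t> &inferred,
                                    const std::vector<int> &recorded) {
  if (recorded.empty()) {
    return PaddingShapeTo4d(inferred);
  }
  std::vector<size_t> host_shape;
  (void)std::transform(recorded.begin(), recorded.end(),
                       std::back_inserter(host_shape), IntToSize);
  return host_shape;
}

int main() {
  // A 2D shape recorded by the runtime is used as-is: {32, 64}.
  auto shape = SelectHostShape({32, 64, 1, 1}, {32, 64});
  return shape.size() == 2 ? 0 : 1;
}
```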
@@ -47,6 +47,7 @@ class AscendDeviceAddress : public DeviceAddress {
   bool LoadMemToHost(bool dump_mode, const std::string &tensor_name, int execution_order, const std::string &host_fmt,
                      const std::vector<int> &host_shape, TypeId host_type, size_t slot, Debugger *debugger) const;
 #endif
  private:
   bool SyncDeviceToHostAndConvertFormat(const std::vector<int> &shape, size_t size, TypeId type, void *host_ptr) const;
   bool ConvertFormatAndSyncHostToDevice(const std::vector<int> &shape, size_t size, TypeId type,
...
@@ -63,6 +63,7 @@ class DeviceAddress {
   size_t GetSize() const { return size_; }
   std::string format() const { return format_; }
   TypeId type_id() const { return type_id_; }
+  void set_host_shape(const std::vector<int> &shape) { host_shape_ = shape; }
   virtual void set_status(DeviceAddressStatus status) {}
   virtual DeviceAddressStatus status() const { return DeviceAddressStatus::kInDevice; }
   virtual DeviceAddressType DeviceType() const { return DeviceAddressType::kUnknown; }
@@ -77,6 +78,7 @@ class DeviceAddress {
   string format_{"DefaultFormat"};
   TypeId type_id_{kNumberTypeFloat16};
   bool from_mem_pool_{false};
+  std::vector<int> host_shape_{};
   friend class KernelRuntime;
   friend class MemoryManager;
   friend class mindspore::device::ascend::tasksink::TaskGenerator;
...
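The two header hunks above add the storage that the sync path reads. A trimmed-down sketch of the pattern (hypothetical, omitting the pointer, size, format, and type members the real class carries):

```cpp
#include <vector>

// Trimmed-down sketch of the DeviceAddress additions above.
class DeviceAddress {
 public:
  void set_host_shape(const std::vector<int> &shape) { host_shape_ = shape; }

 protected:
  // Empty by default, which the sync path treats as "no recorded shape,
  // pad the inferred host shape to 4D instead".
  std::vector<int> host_shape_{};
};
```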
@@ -259,6 +259,7 @@ void KernelRuntime::RunOpAssignOutputMemory(const AnfNodePtr &kernel) {
     std::string output_format = AnfAlgo::GetOutputFormat(kernel, i);
     auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
     auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type);
+    device_address->set_host_shape(trans::GetRuntimePaddingShape(kernel, i));
     MS_EXCEPTION_IF_NULL(device_address);
     auto ret = mem_manager_->MallocMemFromMemPool(device_address, output_sizes[i]);
     if (!ret) {
@@ -507,7 +508,9 @@ void KernelRuntime::AssignNodeOutputMem(int flag, const AnfNodePtr &node, int in
     }
     std::string output_format = AnfAlgo::GetOutputFormat(node, i);
     auto output_type = AnfAlgo::GetOutputDeviceDataType(node, i);
-    AnfAlgo::SetOutputAddr(CreateDeviceAddress(ptr, output_sizes[i], output_format, output_type), i, node.get());
+    auto device_address = CreateDeviceAddress(ptr, output_sizes[i], output_format, output_type);
+    device_address->set_host_shape(trans::GetRuntimePaddingShape(node, i));
+    AnfAlgo::SetOutputAddr(device_address, i, node.get());
   }
 }
...
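Both runtime hunks follow the same recipe: create the output address, record the node's runtime padding shape on it, then hand it off. A hedged usage sketch, reusing the trimmed DeviceAddress above; `GetRuntimePaddingShape` here is only a stub for `trans::GetRuntimePaddingShape(node, index)`:

```cpp
#include <memory>
#include <vector>

// Stub for trans::GetRuntimePaddingShape(node, index): in the real runtime it
// derives the host-side shape of output `index` of `node`.
std::vector<int> GetRuntimePaddingShape() { return {32, 64}; }

int main() {
  // Mirrors RunOpAssignOutputMemory / AssignNodeOutputMem: the shape is
  // attached right after the address is created, before memory allocation.
  auto device_address = std::make_shared<DeviceAddress>();
  device_address->set_host_shape(GetRuntimePaddingShape());
  return 0;
}
```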
@@ -238,16 +238,11 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
   }
   auto optimizer = std::make_shared<GraphOptimizer>();
   auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm");
-  if (context_ptr->execution_mode() == kPynativeMode) {
-    ir_fusion_pm->AddPass(std::make_shared<BnSplit>());
-    ir_fusion_pm->AddPass(std::make_shared<BnGradSplit>());
-  } else {
-    ir_fusion_pm->AddPass(std::make_shared<BatchNormGradSplit>());
-    ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>());
-    ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormFusion>());
-    ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion0>());
-    ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion1>());
-  }
+  ir_fusion_pm->AddPass(std::make_shared<BatchNormGradSplit>());
+  ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>());
+  ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormFusion>());
+  ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion0>());
+  ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion1>());
   ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());
   if (context_ptr->ir_fusion_flag()) {
     AddAscendBackendOptionalIRFusion(ir_fusion_pm.get());
@@ -287,8 +282,11 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne
   }
   auto optimizer = std::make_shared<GraphOptimizer>();
   auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm");
-  ir_fusion_pm->AddPass(std::make_shared<BnSplit>());
+  ir_fusion_pm->AddPass(std::make_shared<BatchNormGradSplit>());
   ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>());
+  ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormFusion>());
+  ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion0>());
+  ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion1>());
   ir_fusion_pm->AddPass(std::make_shared<TopKSplit>());
   ir_fusion_pm->AddPass(std::make_shared<AddnFission>());
   ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());
...
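The effect of these two hunks is that the graph-mode and PyNative (run-op) pipelines now register the same BatchNorm-related passes, so both modes lower BatchNorm through identical IR rewrites. A minimal sketch of the registration pattern itself (hypothetical; the real GraphOptimizer and PassManager carry far more):

```cpp
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct Graph {};  // placeholder for the kernel graph IR

// Minimal pass interface: each pass rewrites the graph in place.
class Pass {
 public:
  virtual ~Pass() = default;
  virtual void Run(Graph *graph) = 0;
};

// Passes execute in registration order, which is why both call sites above
// must register the same BatchNorm passes to get identical lowering.
class PassManager {
 public:
  explicit PassManager(std::string name) : name_(std::move(name)) {}
  void AddPass(const std::shared_ptr<Pass> &pass) { passes_.push_back(pass); }
  void Run(Graph *graph) const {
    for (const auto &pass : passes_) {
      pass->Run(graph);
    }
  }

 private:
  std::string name_;
  std::vector<std::shared_ptr<Pass>> passes_;
};
```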
@@ -84,14 +84,13 @@ class _BatchNorm(Cell):
         self.dtype = P.DType()
         self.reshape = P.Reshape()
         self.is_ascend = context.get_context("device_target") == "Ascend"
-        self.is_graph_mode = context.get_context("mode") == context.GRAPH_MODE
         self.momentum = 1.0 - momentum
         if context.get_context("enable_ge"):
             self.is_ge_backend = True
         else:
             self.is_ge_backend = False
-        if self.is_graph_mode and (self.is_ge_backend or self.is_ascend):
+        if self.is_ge_backend or self.is_ascend:
             self.bn_train = P.BatchNorm(is_training=True,
                                         epsilon=self.eps)
         else:
@@ -153,7 +152,7 @@ class _BatchNorm(Cell):
         if self.is_ge_backend and self.is_global:
             axes, re_shape = _shape_infer(F.shape(x), self.num_features)
             y = self._global_sync(x, axes, re_shape)
-        elif self.is_graph_mode and (self.is_ge_backend or self.is_ascend):
+        elif self.is_ge_backend or self.is_ascend:
             if self.is_global:
                 axes, re_shape = _shape_infer(F.shape(x), self.num_features)
                 y = self._global_sync(x, axes, re_shape)
...