Commit 065b8cbd authored by mindspore-ci-bot, committed by Gitee

!2700 use same batchnorm operation in PyNative mode and Graph mode

Merge pull request !2700 from chujinjin/change_pynative_batchnorm_same_as_graph_mode
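
The changes below make nn.BatchNorm layers take the same batchnorm path on Ascend whether the network runs in PyNative or Graph mode. As a rough usage-level illustration only (not part of the commit; it assumes an Ascend environment and a MindSpore build from this period, and the feature count and tensor shape are arbitrary):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context

# Flip between PYNATIVE_MODE and GRAPH_MODE; with this change the Ascend
# backend lowers the layer through the same batchnorm operator either way.
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")

bn = nn.BatchNorm2d(num_features=4)
x = Tensor(np.random.randn(2, 4, 8, 8).astype(np.float32))
y = bn(x)          # expected to match the GRAPH_MODE result
print(y.shape)     # (2, 4, 8, 8)
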
@@ -174,7 +174,13 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const std::vector<int
if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NDHWC) {
device_shape = trans::TransShapeToDevice(host_shape, format_);
} else {
host_shape = trans::PaddingShapeTo4d(host_shape);
if (host_shape_.empty()) {
host_shape = trans::PaddingShapeTo4d(host_shape);
} else {
host_shape.clear();
(void)std::transform(host_shape_.begin(), host_shape_.end(), std::back_inserter(host_shape), IntToSize);
}
device_shape = trans::TransShapeToDevice(host_shape, format_);
}
if (type_id_ != type) {
......
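
The hunk above changes how the host shape is recovered when device memory is synced back to the host: if the device address carries a recorded host shape (host_shape_), that shape is converted and reused; only when it is empty does the code fall back to padding the inferred shape to 4-D. A rough Python sketch of that branch, for illustration only (recover_host_shape and pad_shape_to_4d are hypothetical stand-ins; the real padding convention lives in trans::PaddingShapeTo4d):

def pad_shape_to_4d(shape):
    # Hypothetical stand-in for trans::PaddingShapeTo4d: pad with 1s to rank 4.
    return list(shape) + [1] * (4 - len(shape)) if len(shape) < 4 else list(shape)

def recover_host_shape(recorded_host_shape, inferred_shape):
    # Mirrors the new branch: prefer the shape recorded on the device address,
    # and fall back to 4-D padding only when nothing was recorded.
    if not recorded_host_shape:
        return pad_shape_to_4d(inferred_shape)
    return [int(d) for d in recorded_host_shape]

print(recover_host_shape([], [16, 32]))        # fallback path, padded to rank 4
print(recover_host_shape([16, 32], [16, 32]))  # recorded shape wins: [16, 32]
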
@@ -47,6 +47,7 @@ class AscendDeviceAddress : public DeviceAddress {
bool LoadMemToHost(bool dump_mode, const std::string &tensor_name, int execution_order, const std::string &host_fmt,
const std::vector<int> &host_shape, TypeId host_type, size_t slot, Debugger *debugger) const;
#endif
private:
bool SyncDeviceToHostAndConvertFormat(const std::vector<int> &shape, size_t size, TypeId type, void *host_ptr) const;
bool ConvertFormatAndSyncHostToDevice(const std::vector<int> &shape, size_t size, TypeId type,
......
@@ -63,6 +63,7 @@ class DeviceAddress {
size_t GetSize() const { return size_; }
std::string format() const { return format_; }
TypeId type_id() const { return type_id_; }
void set_host_shape(const std::vector<int> &shape) { host_shape_ = shape; }
virtual void set_status(DeviceAddressStatus status) {}
virtual DeviceAddressStatus status() const { return DeviceAddressStatus::kInDevice; }
virtual DeviceAddressType DeviceType() const { return DeviceAddressType::kUnknown; }
@@ -77,6 +78,7 @@ class DeviceAddress {
string format_{"DefaultFormat"};
TypeId type_id_{kNumberTypeFloat16};
bool from_mem_pool_{false};
std::vector<int> host_shape_{};
friend class KernelRuntime;
friend class MemoryManager;
friend class mindspore::device::ascend::tasksink::TaskGenerator;
......
@@ -259,6 +259,7 @@ void KernelRuntime::RunOpAssignOutputMemory(const AnfNodePtr &kernel) {
std::string output_format = AnfAlgo::GetOutputFormat(kernel, i);
auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type);
device_address->set_host_shape(trans::GetRuntimePaddingShape(kernel, i));
MS_EXCEPTION_IF_NULL(device_address);
auto ret = mem_manager_->MallocMemFromMemPool(device_address, output_sizes[i]);
if (!ret) {
@@ -507,7 +508,9 @@ void KernelRuntime::AssignNodeOutputMem(int flag, const AnfNodePtr &node, int in
}
std::string output_format = AnfAlgo::GetOutputFormat(node, i);
auto output_type = AnfAlgo::GetOutputDeviceDataType(node, i);
AnfAlgo::SetOutputAddr(CreateDeviceAddress(ptr, output_sizes[i], output_format, output_type), i, node.get());
auto device_address = CreateDeviceAddress(ptr, output_sizes[i], output_format, output_type);
device_address->set_host_shape(trans::GetRuntimePaddingShape(node, i));
AnfAlgo::SetOutputAddr(device_address, i, node.get());
}
}
......
@@ -238,16 +238,11 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
}
auto optimizer = std::make_shared<GraphOptimizer>();
auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm");
if (context_ptr->execution_mode() == kPynativeMode) {
ir_fusion_pm->AddPass(std::make_shared<BnSplit>());
ir_fusion_pm->AddPass(std::make_shared<BnGradSplit>());
} else {
ir_fusion_pm->AddPass(std::make_shared<BatchNormGradSplit>());
ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>());
ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormFusion>());
ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion0>());
ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion1>());
}
ir_fusion_pm->AddPass(std::make_shared<BatchNormGradSplit>());
ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>());
ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormFusion>());
ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion0>());
ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion1>());
ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());
if (context_ptr->ir_fusion_flag()) {
AddAscendBackendOptionalIRFusion(ir_fusion_pm.get());
@@ -287,8 +282,11 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne
}
auto optimizer = std::make_shared<GraphOptimizer>();
auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm");
ir_fusion_pm->AddPass(std::make_shared<BnSplit>());
ir_fusion_pm->AddPass(std::make_shared<BatchNormGradSplit>());
ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>());
ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormFusion>());
ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion0>());
ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion1>());
ir_fusion_pm->AddPass(std::make_shared<TopKSplit>());
ir_fusion_pm->AddPass(std::make_shared<AddnFission>());
ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());
......
@@ -84,14 +84,13 @@ class _BatchNorm(Cell):
self.dtype = P.DType()
self.reshape = P.Reshape()
self.is_ascend = context.get_context("device_target") == "Ascend"
self.is_graph_mode = context.get_context("mode") == context.GRAPH_MODE
self.momentum = 1.0 - momentum
if context.get_context("enable_ge"):
self.is_ge_backend = True
else:
self.is_ge_backend = False
if self.is_graph_mode and (self.is_ge_backend or self.is_ascend):
if self.is_ge_backend or self.is_ascend:
self.bn_train = P.BatchNorm(is_training=True,
epsilon=self.eps)
else:
@@ -153,7 +152,7 @@ class _BatchNorm(Cell):
if self.is_ge_backend and self.is_global:
axes, re_shape = _shape_infer(F.shape(x), self.num_features)
y = self._global_sync(x, axes, re_shape)
elif self.is_graph_mode and (self.is_ge_backend or self.is_ascend):
elif self.is_ge_backend or self.is_ascend:
if self.is_global:
axes, re_shape = _shape_infer(F.shape(x), self.num_features)
y = self._global_sync(x, axes, re_shape)
......
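
With the is_graph_mode check removed, the layer above selects P.BatchNorm(is_training=True) on the GE backend or Ascend regardless of execution mode. A sketch of invoking that primitive directly, for illustration only (shapes and epsilon are arbitrary; the call follows the five-input form the layer itself uses, and the first element of the returned tuple is taken to be the normalized output):

import numpy as np
from mindspore import Tensor, context
from mindspore.ops import operations as P

context.set_context(device_target="Ascend")    # same operator in both modes now

bn_train = P.BatchNorm(is_training=True, epsilon=1e-5)

c = 4
x = Tensor(np.random.randn(2, c, 8, 8).astype(np.float32))
scale = Tensor(np.ones(c, np.float32))
bias = Tensor(np.zeros(c, np.float32))
moving_mean = Tensor(np.zeros(c, np.float32))
moving_var = Tensor(np.ones(c, np.float32))

out = bn_train(x, scale, bias, moving_mean, moving_var)
print(out[0].shape)                            # (2, 4, 8, 8)
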