Commit fb4b4f8d, authored by Krzysztof Binias

Refactor code

Parent 50d3e6e9
@@ -62,56 +62,42 @@ class BatchNormMKLDNNHandler : public platform::MKLDNNHandler {
         batch_norm_pd_->variance_primitive_desc(), ptr, "@variance_mem_p");
   }

-  std::shared_ptr<batch_norm_fwd> AcquireTestBatchNormFwd(
+  std::shared_ptr<batch_norm_fwd> AcquireTestTrainingBatchNormFwd(
       std::shared_ptr<memory> src_memory,
-      const mkldnn::primitive::at &mean_memory,
-      const mkldnn::primitive::at &variance_memory,
       std::shared_ptr<memory> scaleshift_memory,
-      std::shared_ptr<memory> dst_memory) {
+      std::shared_ptr<memory> dst_memory, std::shared_ptr<memory> mean_memory,
+      std::shared_ptr<memory> variance_memory, bool is_test) {
     auto prim_key = key_ + "@batch_norm_p";
     auto batch_norm_p =
         std::static_pointer_cast<batch_norm_fwd>(dev_ctx_.GetBlob(prim_key));
-    PADDLE_ENFORCE(
-        (batch_norm_p != nullptr) || (is_reusing_ == false),
-        "Fail to find batch norm primitive for test in device context");
-    if (batch_norm_p == nullptr) {
-      batch_norm_p = std::make_shared<batch_norm_fwd>(
-          *batch_norm_pd_, *src_memory, mean_memory, variance_memory,
-          *scaleshift_memory, *dst_memory);
-      dev_ctx_.SetBlob(prim_key, batch_norm_p);
-    } else {
-      is_reusing_ = true;
-    }
-    return batch_norm_p;
-  }
+    PADDLE_ENFORCE((batch_norm_p != nullptr) || !is_reusing_,
+                   "Fail to find batch norm primitive in device context");
-  std::shared_ptr<batch_norm_fwd> AcquireTrainingBatchNormFwd(
-      std::shared_ptr<memory> src_memory,
-      std::shared_ptr<memory> scaleshift_memory,
-      std::shared_ptr<memory> dst_memory, std::shared_ptr<memory> mean_memory,
-      std::shared_ptr<memory> variance_memory) {
-    auto prim_key = key_ + "@batch_norm_p";
-    auto batch_norm_p =
-        std::static_pointer_cast<batch_norm_fwd>(dev_ctx_.GetBlob(prim_key));
-    PADDLE_ENFORCE(
-        (batch_norm_p != nullptr) || (is_reusing_ == false),
-        "Fail to find batch norm primitive for training in device context");
     if (batch_norm_p == nullptr) {
-      batch_norm_p = std::make_shared<batch_norm_fwd>(
-          *batch_norm_pd_, *src_memory, *scaleshift_memory, *dst_memory,
-          *mean_memory, *variance_memory);
+      if (is_test) {
+        batch_norm_p = std::make_shared<batch_norm_fwd>(
+            *batch_norm_pd_, *src_memory,
+            (const mkldnn::primitive::at &)*mean_memory,
+            (const mkldnn::primitive::at &)*variance_memory,
+            *scaleshift_memory, *dst_memory);
+      } else {
+        batch_norm_p = std::make_shared<batch_norm_fwd>(
+            *batch_norm_pd_, *src_memory, *scaleshift_memory, *dst_memory,
+            *mean_memory, *variance_memory);
+      }
       dev_ctx_.SetBlob(prim_key, batch_norm_p);
     } else {
       is_reusing_ = true;
     }
     return batch_norm_p;
   }

   //
   static std::string GetHash(const memory::dims &input_dims, float epsilon,
                              unsigned flag, bool is_test, memory::format format,
-                             const std::string &suffix) {
+                             const std::string &suffix = "") {
     auto dims2str = [](const memory::dims &operand_dims) {
       std::string dstr = "";
       for (size_t i = 0; i < operand_dims.size(); ++i) {
@@ -128,19 +114,6 @@ class BatchNormMKLDNNHandler : public platform::MKLDNNHandler {
   std::shared_ptr<batch_norm_fwd::primitive_desc> batch_norm_pd_;
 };

-std::string gethash(const memory::dims &input_dims, float epsilon,
-                    unsigned flag, bool is_test, memory::format format) {
-  auto dims2str = [](const memory::dims &operand_dims) {
-    std::string dstr = "";
-    for (size_t i = 0; i < operand_dims.size(); ++i) {
-      dstr += std::to_string(operand_dims[i]) + "-";
-    }
-    return dstr;
-  };
-  return dims2str(input_dims) + std::to_string(epsilon) + std::to_string(flag) +
-         std::to_string(is_test) + std::to_string(format);
-}
-
 std::shared_ptr<memory> UpdateMemoryData(
     const platform::MKLDNNDeviceContext &dev_ctx, const std::string &key,
     void *new_ptr) {
@@ -274,10 +247,9 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
           handler.AcquireVarianceMemoryFromPrimitive(
               to_void_cast(variance_data));

-      batch_norm_p = handler.AcquireTestBatchNormFwd(
-          src_memory, (const mkldnn::primitive::at &)*mean_memory,
-          (const mkldnn::primitive::at &)*variance_memory, scaleshift_memory,
-          dst_memory);
+      batch_norm_p = handler.AcquireTestTrainingBatchNormFwd(
+          src_memory, scaleshift_memory, dst_memory, mean_memory,
+          variance_memory, true);
     } else {
       // create mkldnn memory for stats (as output)
       std::shared_ptr<memory> mean_memory =
@@ -285,9 +257,9 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       std::shared_ptr<memory> variance_memory =
           handler.AcquireVarianceMemoryFromPrimitive(batch_variance_data);

-      batch_norm_p = handler.AcquireTrainingBatchNormFwd(
+      batch_norm_p = handler.AcquireTestTrainingBatchNormFwd(
           src_memory, scaleshift_memory, dst_memory, mean_memory,
-          variance_memory);
+          variance_memory, false);
     }

     y->set_layout(DataLayout::kMKLDNN);
@@ -377,7 +349,8 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     // keys for primitives reuse
     const std::string key_with_hash =
-        key + gethash(src_tz, epsilon, flags, false, input_format);
+        key + BatchNormMKLDNNHandler::GetHash(src_tz, epsilon, flags, false,
+                                              input_format);
     const std::string key_batch_norm_bwd_p =
         key_with_hash + "@batch_norm_bwd_p";
     const std::string key_batch_norm_src_mem_p =
......
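Note on the refactor: the two nearly identical members AcquireTestBatchNormFwd and AcquireTrainingBatchNormFwd are merged into a single AcquireTestTrainingBatchNormFwd, so the cache lookup, the reuse PADDLE_ENFORCE check, and the SetBlob registration exist once, and only the primitive constructor call branches on the trailing is_test flag. Below is a minimal standalone sketch of that acquire-or-create pattern; BlobMap, FwdPrimitive, and AcquireFwd are simplified stand-ins invented for illustration, not PaddlePaddle or MKL-DNN types.

#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for the cached forward primitive.
struct FwdPrimitive {
  bool is_test;  // records which constructor variant was chosen
};

// Hypothetical stand-in for the device context's blob cache.
using BlobMap = std::unordered_map<std::string, std::shared_ptr<FwdPrimitive>>;

// Mirrors the merged AcquireTestTrainingBatchNormFwd: look the primitive up
// by key; only on a cache miss construct it, branching on is_test once.
std::shared_ptr<FwdPrimitive> AcquireFwd(BlobMap &cache, const std::string &key,
                                         bool is_test) {
  auto it = cache.find(key);
  if (it != cache.end()) return it->second;  // reuse path (GetBlob hit)
  auto p = std::make_shared<FwdPrimitive>(FwdPrimitive{is_test});
  cache.emplace(key, p);  // SetBlob equivalent
  return p;
}

int main() {
  BlobMap cache;
  auto p1 = AcquireFwd(cache, "bn@batch_norm_p", /*is_test=*/true);
  auto p2 = AcquireFwd(cache, "bn@batch_norm_p", true);  // cache hit
  std::cout << (p1 == p2) << "\n";  // prints 1: same primitive reused
}

The design choice this illustrates: by isolating the branch to the constructor call, the inference and training call sites in the kernel (the two AcquireTestTrainingBatchNormFwd calls above) differ only in the boolean they pass.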
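The second half of the change replaces the free function gethash with the handler's static GetHash (now with a defaulted suffix parameter), used by the gradient kernel to build primitive-reuse keys. The removed body shows the key layout: dimensions joined with "-", then the stringified epsilon, flags, is_test, and format. A sketch of that key construction, under the assumption that memory::dims behaves like a vector of ints, that memory::format stringifies as an integer, and that the suffix is appended last (the GetHash body is truncated in the hunk, so the suffix handling is an assumption):

#include <iostream>
#include <string>
#include <vector>

// Simplified stand-in: mkldnn::memory::dims is treated as a vector of ints,
// and memory::format as a plain int.
using Dims = std::vector<int>;

// Mirrors the key layout of the removed gethash / the new static GetHash:
// dims joined by '-', then epsilon, flag, is_test, format, optional suffix.
std::string GetHash(const Dims &input_dims, float epsilon, unsigned flag,
                    bool is_test, int format, const std::string &suffix = "") {
  std::string dstr;
  for (int d : input_dims) dstr += std::to_string(d) + "-";
  return dstr + std::to_string(epsilon) + std::to_string(flag) +
         std::to_string(is_test) + std::to_string(format) + suffix;
}

int main() {
  // e.g. a 2x3x4x5 tensor, epsilon 1e-5, flags 1, inference mode, format 8
  std::cout << GetHash({2, 3, 4, 5}, 1e-5f, 1, true, 8) << "\n";
  // prints "2-3-4-5-0.000010118": dims, then epsilon/flag/is_test/format
}

Making GetHash a static member with a defaulted suffix lets both the forward handler and the backward kernel (see the last hunk) share one key scheme instead of maintaining the duplicate free function.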