提交 80484245 编写于 作者: J Jacek Czaja

- disabled caching of layer norm

上级 1148ce67
......@@ -19,19 +19,15 @@ namespace paddle {
namespace operators {
template <typename T>
class LayerNormMKLDNNHandler
: public platform::MKLDNNHandlerT<T, dnnl::layer_normalization_forward> {
class LayerNormMKLDNNHandler : public platform::MKLDNNHandlerNoCachingT<
T, dnnl::layer_normalization_forward> {
public:
LayerNormMKLDNNHandler(const std::vector<int64_t>& dims, const float& epsilon,
const dnnl::normalization_flags& flags,
const bool& is_test, const MKLDNNMemoryFormat fmt,
const platform::MKLDNNDeviceContext& dev_ctx,
platform::Place cpu_place,
const std::string& uniq_name)
: platform::MKLDNNHandlerT<T, dnnl::layer_normalization_forward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dev_ctx, dims, uniq_name)) {
if (!this->isCached()) {
const mkldnn::engine engine, platform::Place cpu_place)
: platform::MKLDNNHandlerNoCachingT<T, dnnl::layer_normalization_forward>(
mkldnn_engine, cpu_place) {
auto md = dnnl::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
if (!is_test) {
// TODO(grygielski) Delete forcing stats_md after DNNL 1.2 is introduced
......@@ -39,25 +35,20 @@ class LayerNormMKLDNNHandler
{begin(dims), end(dims) - 1}, platform::MKLDNNGetDataType<float>(),
platform::MKLDNNFormatForSize(dims.size() - 1,
MKLDNNMemoryFormat::nchw));
this->AcquireForwardPrimitiveDescriptor(
dnnl::prop_kind::forward_training, md, stats_md, epsilon, flags);
this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_training,
md, stats_md, epsilon, flags);
} else {
this->AcquireForwardPrimitiveDescriptor(
dnnl::prop_kind::forward_inference, md, epsilon, flags);
}
}
}
std::shared_ptr<dnnl::memory> AcquireScaleShiftMemory() {
return this->AcquireMemoryFromPrimitive("@scaleshift_mem_p");
}
std::shared_ptr<dnnl::memory> AcquireScaleShiftMemory(
std::vector<float>& scaleshift_data) {
// scaleshift_data comes from temporary buffer so we need to copy it into
// created memory primitivie
auto scaleshift_mem = this->AcquireMemoryFromPrimitive(
this->fwd_pd_->weights_desc(), "@scaleshift_mem_p");
auto scaleshift_mem =
this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc());
auto data_ptr = scaleshift_mem->get_data_handle();
std::size_t num_bytes = scaleshift_data.size() * sizeof(float);
std::memcpy(data_ptr, scaleshift_data.data(), num_bytes);
......@@ -68,7 +59,7 @@ class LayerNormMKLDNNHandler
T* mean_data = mean->mutable_data<T>(this->place_,
this->fwd_pd_->mean_desc().get_size());
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->mean_desc(),
mean_data, "@mean_mem_p");
mean_data);
}
std::shared_ptr<dnnl::memory> AcquireVarianceMemory(
......@@ -76,7 +67,7 @@ class LayerNormMKLDNNHandler
T* variance_data = variance->mutable_data<T>(
this->place_, this->fwd_pd_->variance_desc().get_size());
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->variance_desc(),
variance_data, "@variance_mem_p");
variance_data);
}
};
......@@ -95,6 +86,7 @@ class LayerNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
auto src_tz = paddle::framework::vectorize(x->dims());
PADDLE_ENFORCE_EQ(begin_norm_axis, (src_tz.size() - 1),
......@@ -112,8 +104,8 @@ class LayerNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
}
LayerNormMKLDNNHandler<T> handler(src_tz, epsilon, flags, is_test,
x->format(), dev_ctx, ctx.GetPlace(),
ctx.OutputName("Y"));
x->format(), mkldnn_engine,
ctx.GetPlace());
auto src_memory = handler.AcquireSrcMemory(x);
auto dst_memory = handler.AcquireDstMemory(y);
......@@ -139,9 +131,8 @@ class LayerNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
args.insert({DNNL_ARG_VARIANCE, *variance_memory});
}
auto scaleshift_memory = handler.AcquireScaleShiftMemory();
auto scaleshift_memory = nullptr;
if (with_scaleshift) {
if (scaleshift_memory == nullptr || !is_test) {
auto scale_tz = paddle::framework::vectorize(scale->dims());
const unsigned int C = scale_tz[0];
......@@ -156,7 +147,6 @@ class LayerNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
bias->data<float>() + C);
scaleshift_memory = handler.AcquireScaleShiftMemory(scaleshift_data);
}
args.insert({DNNL_ARG_SCALE_SHIFT, *scaleshift_memory});
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册