Commit 80484245, authored by Jacek Czaja

- disabled caching of layer norm

Parent: 1148ce67
@@ -19,45 +19,36 @@ namespace paddle {
 namespace operators {

 template <typename T>
-class LayerNormMKLDNNHandler
-    : public platform::MKLDNNHandlerT<T, dnnl::layer_normalization_forward> {
+class LayerNormMKLDNNHandler : public platform::MKLDNNHandlerNoCachingT<
+                                   T, dnnl::layer_normalization_forward> {
  public:
   LayerNormMKLDNNHandler(const std::vector<int64_t>& dims, const float& epsilon,
                          const dnnl::normalization_flags& flags,
                          const bool& is_test, const MKLDNNMemoryFormat fmt,
-                         const platform::MKLDNNDeviceContext& dev_ctx,
-                         platform::Place cpu_place,
-                         const std::string& uniq_name)
-      : platform::MKLDNNHandlerT<T, dnnl::layer_normalization_forward>(
-            dev_ctx, dev_ctx.GetEngine(), cpu_place,
-            platform::CreateKey(dev_ctx, dims, uniq_name)) {
-    if (!this->isCached()) {
-      auto md = dnnl::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
-      if (!is_test) {
-        // TODO(grygielski) Delete forcing stats_md after DNNL 1.2 is introduced
-        auto stats_md = dnnl::memory::desc(
-            {begin(dims), end(dims) - 1}, platform::MKLDNNGetDataType<float>(),
-            platform::MKLDNNFormatForSize(dims.size() - 1,
-                                          MKLDNNMemoryFormat::nchw));
-        this->AcquireForwardPrimitiveDescriptor(
-            dnnl::prop_kind::forward_training, md, stats_md, epsilon, flags);
-      } else {
-        this->AcquireForwardPrimitiveDescriptor(
-            dnnl::prop_kind::forward_inference, md, epsilon, flags);
-      }
-    }
-  }
-
-  std::shared_ptr<dnnl::memory> AcquireScaleShiftMemory() {
-    return this->AcquireMemoryFromPrimitive("@scaleshift_mem_p");
-  }
+                         const mkldnn::engine engine, platform::Place cpu_place)
+      : platform::MKLDNNHandlerNoCachingT<T, dnnl::layer_normalization_forward>(
+            engine, cpu_place) {
+    auto md = dnnl::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
+    if (!is_test) {
+      // TODO(grygielski) Delete forcing stats_md after DNNL 1.2 is introduced
+      auto stats_md = dnnl::memory::desc(
+          {begin(dims), end(dims) - 1}, platform::MKLDNNGetDataType<float>(),
+          platform::MKLDNNFormatForSize(dims.size() - 1,
+                                        MKLDNNMemoryFormat::nchw));
+      this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_training,
+                                              md, stats_md, epsilon, flags);
+    } else {
+      this->AcquireForwardPrimitiveDescriptor(
+          dnnl::prop_kind::forward_inference, md, epsilon, flags);
+    }
+  }

   std::shared_ptr<dnnl::memory> AcquireScaleShiftMemory(
       std::vector<float>& scaleshift_data) {
     // scaleshift_data comes from temporary buffer so we need to copy it into
     // created memory primitive
-    auto scaleshift_mem = this->AcquireMemoryFromPrimitive(
-        this->fwd_pd_->weights_desc(), "@scaleshift_mem_p");
+    auto scaleshift_mem =
+        this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc());
     auto data_ptr = scaleshift_mem->get_data_handle();
     std::size_t num_bytes = scaleshift_data.size() * sizeof(float);
     std::memcpy(data_ptr, scaleshift_data.data(), num_bytes);
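The hunk above is the core of the change: the handler now derives from platform::MKLDNNHandlerNoCachingT, so the platform::CreateKey(...) cache key, the isCached() early-out, and the "@..." memory names all disappear, and the forward primitive descriptor is rebuilt on every call. As a rough, simplified model of the two acquisition patterns (CachedHandlerModel and NoCachingHandlerModel are hypothetical illustration types, not the actual Paddle handler classes):

// Illustrative model only -- not the actual Paddle implementation.
#include <memory>
#include <string>
#include <unordered_map>

// Stand-in for an expensive-to-build object such as a forward primitive
// descriptor.
struct FwdPrimitiveDesc {
  explicit FwdPrimitiveDesc(float epsilon) : epsilon(epsilon) {}
  float epsilon;
};

// Old pattern: descriptors are memoized in a map keyed by a string built
// from shapes, the op's output name, and so on.
class CachedHandlerModel {
 public:
  std::shared_ptr<FwdPrimitiveDesc> Acquire(const std::string& key,
                                            float epsilon) {
    auto it = cache_.find(key);
    if (it != cache_.end()) return it->second;  // the "isCached()" fast path
    auto pd = std::make_shared<FwdPrimitiveDesc>(epsilon);
    cache_.emplace(key, pd);
    return pd;
  }

 private:
  std::unordered_map<std::string, std::shared_ptr<FwdPrimitiveDesc>> cache_;
};

// New pattern: no key and no map; a fresh descriptor is built per call.
class NoCachingHandlerModel {
 public:
  std::shared_ptr<FwdPrimitiveDesc> Acquire(float epsilon) {
    return std::make_shared<FwdPrimitiveDesc>(epsilon);
  }
};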
@@ -68,7 +59,7 @@ class LayerNormMKLDNNHandler
     T* mean_data = mean->mutable_data<T>(this->place_,
                                          this->fwd_pd_->mean_desc().get_size());
     return this->AcquireMemoryFromPrimitive(this->fwd_pd_->mean_desc(),
-                                            mean_data, "@mean_mem_p");
+                                            mean_data);
   }

   std::shared_ptr<dnnl::memory> AcquireVarianceMemory(
@@ -76,7 +67,7 @@ class LayerNormMKLDNNHandler
     T* variance_data = variance->mutable_data<T>(
         this->place_, this->fwd_pd_->variance_desc().get_size());
     return this->AcquireMemoryFromPrimitive(this->fwd_pd_->variance_desc(),
-                                            variance_data, "@variance_mem_p");
+                                            variance_data);
   }
 };
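Both statistics buffers above are sized from the forward primitive descriptor, and their shape is the input shape with the normalized (last) axis dropped, matching the stats_md built in the constructor. A minimal sketch of that shape rule, using a hypothetical StatsDims helper rather than any Paddle API:

#include <cstdint>
#include <vector>

// For layer_norm with begin_norm_axis == rank - 1, mean/variance hold one
// value per element of the leading axes, e.g. {N, C, H} -> {N, C}.
// Assumes a non-empty shape.
std::vector<int64_t> StatsDims(const std::vector<int64_t>& src_dims) {
  return std::vector<int64_t>(src_dims.begin(), src_dims.end() - 1);
}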
@@ -95,6 +86,7 @@ class LayerNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto& dev_ctx =
         ctx.template device_context<platform::MKLDNNDeviceContext>();
+    const auto& mkldnn_engine = dev_ctx.GetEngine();

     auto src_tz = paddle::framework::vectorize(x->dims());
     PADDLE_ENFORCE_EQ(begin_norm_axis, (src_tz.size() - 1),
@@ -112,8 +104,8 @@ class LayerNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     }

     LayerNormMKLDNNHandler<T> handler(src_tz, epsilon, flags, is_test,
-                                      x->format(), dev_ctx, ctx.GetPlace(),
-                                      ctx.OutputName("Y"));
+                                      x->format(), mkldnn_engine,
+                                      ctx.GetPlace());

     auto src_memory = handler.AcquireSrcMemory(x);
     auto dst_memory = handler.AcquireDstMemory(y);
@@ -139,24 +131,22 @@ class LayerNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       args.insert({DNNL_ARG_VARIANCE, *variance_memory});
     }

-    auto scaleshift_memory = handler.AcquireScaleShiftMemory();
+    std::shared_ptr<dnnl::memory> scaleshift_memory = nullptr;
     if (with_scaleshift) {
-      if (scaleshift_memory == nullptr || !is_test) {
-        auto scale_tz = paddle::framework::vectorize(scale->dims());
-        const unsigned int C = scale_tz[0];
-
-        // MKLDNN requires a single piece of memory for scale and shift/bias
-        // data
-        std::vector<float> scaleshift_data;
-        scaleshift_data.reserve(2 * C);
-        scaleshift_data.insert(scaleshift_data.begin(), scale->data<float>(),
-                               scale->data<float>() + C);
-
-        scaleshift_data.insert(scaleshift_data.end(), bias->data<float>(),
-                               bias->data<float>() + C);
-
-        scaleshift_memory = handler.AcquireScaleShiftMemory(scaleshift_data);
-      }
+      auto scale_tz = paddle::framework::vectorize(scale->dims());
+      const unsigned int C = scale_tz[0];
+
+      // MKLDNN requires a single piece of memory for scale and shift/bias
+      // data
+      std::vector<float> scaleshift_data;
+      scaleshift_data.reserve(2 * C);
+      scaleshift_data.insert(scaleshift_data.begin(), scale->data<float>(),
+                             scale->data<float>() + C);
+      scaleshift_data.insert(scaleshift_data.end(), bias->data<float>(),
+                             bias->data<float>() + C);
+
+      scaleshift_memory = handler.AcquireScaleShiftMemory(scaleshift_data);
       args.insert({DNNL_ARG_SCALE_SHIFT, *scaleshift_memory});
     }
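The final hunk keeps the scale/bias handling but drops the cached-memory lookup, so the packed buffer is always rebuilt and handed to AcquireScaleShiftMemory. A standalone sketch of that packing step, assuming scale and bias each hold C floats (PackScaleShift is a hypothetical helper, not part of the kernel):

#include <vector>

// The combined scale_shift weights used here expect the C scale values
// followed by the C shift/bias values in one contiguous buffer.
std::vector<float> PackScaleShift(const float* scale, const float* bias,
                                  unsigned int C) {
  std::vector<float> scaleshift_data;
  scaleshift_data.reserve(2 * C);
  scaleshift_data.insert(scaleshift_data.end(), scale, scale + C);  // scale first
  scaleshift_data.insert(scaleshift_data.end(), bias, bias + C);    // then shift/bias
  return scaleshift_data;
}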