提交 b8322848 编写于 作者: J Jacek Czaja

- Activation onednn caching removed

上级 f6e981f2
...@@ -79,15 +79,14 @@ void eltwise_forward(const framework::ExecutionContext &ctx, ...@@ -79,15 +79,14 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
paddle::platform::errors::PreconditionNotMet( paddle::platform::errors::PreconditionNotMet(
"Operator DNNL eletwise_forward must use CPUPlace")); "Operator DNNL eletwise_forward must use CPUPlace"));
auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>(); auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
const auto *x = ctx.Input<Tensor>("X"); const auto *x = ctx.Input<Tensor>("X");
auto *y = ctx.Output<Tensor>("Out"); auto *y = ctx.Output<Tensor>("Out");
bool is_inplaced = x->IsSharedBufferWith(*y); bool is_inplaced = x->IsSharedBufferWith(*y);
platform::ActivationMKLDNNHandler<T> handler(algorithm, ctx, dev_ctx, platform::ActivationMKLDNNHandler<T> handler(algorithm, ctx, mkldnn_engine, ctx.GetPlace(), x);
ctx.GetPlace(), x,
ctx.InputName("X"), is_inplaced);
auto src_memory_p = handler.AcquireSrcMemory(x); auto src_memory_p = handler.AcquireSrcMemory(x);
auto dst_memory_p = is_inplaced ? src_memory_p : handler.AcquireDstMemory(y); auto dst_memory_p = is_inplaced ? src_memory_p : handler.AcquireDstMemory(y);
......
...@@ -857,7 +857,7 @@ class BinaryMKLDNNHandler : public platform::MKLDNNHandlerNoCachingT<T, dnnl::bi ...@@ -857,7 +857,7 @@ class BinaryMKLDNNHandler : public platform::MKLDNNHandlerNoCachingT<T, dnnl::bi
const framework::Tensor* input) { const framework::Tensor* input) {
const T* input_data = input->data<T>(); const T* input_data = input->data<T>();
return this->AcquireMemoryFromPrimitive( return this->AcquireMemoryFromPrimitive(
this->fwd_pd_->src1_desc(), to_void_cast<T>(input_data), "@src1_mem_p"); this->fwd_pd_->src1_desc(), to_void_cast<T>(input_data));
} }
private: private:
...@@ -980,24 +980,15 @@ class ReductionMKLDNNHandler ...@@ -980,24 +980,15 @@ class ReductionMKLDNNHandler
template <typename T> template <typename T>
class ActivationMKLDNNHandler class ActivationMKLDNNHandler
: public MKLDNNHandlerT<T, mkldnn::eltwise_forward, : public MKLDNNHandlerNoCachingT<T, mkldnn::eltwise_forward,
mkldnn::eltwise_backward> { mkldnn::eltwise_backward> {
public: public:
ActivationMKLDNNHandler(mkldnn::algorithm algorithm, ActivationMKLDNNHandler(mkldnn::algorithm algorithm,
const framework::ExecutionContext& ctx, const framework::ExecutionContext& ctx,
const MKLDNNDeviceContext& dev_ctx, Place cpu_place, const mkldnn::engine engine, Place cpu_place,
const framework::Tensor* in_x, const framework::Tensor* in_x)
const std::string& unique_name, bool is_inplaced) : platform::MKLDNNHandlerNoCachingT<T, mkldnn::eltwise_forward,
: platform::MKLDNNHandlerT<T, mkldnn::eltwise_forward, mkldnn::eltwise_backward>(engine, cpu_place) {
mkldnn::eltwise_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
is_inplaced ? platform::CreateKey(
dev_ctx, framework::vectorize(in_x->dims()), "a",
algorithm, unique_name)
: platform::CreateKey(
dev_ctx, framework::vectorize(in_x->dims()), "a",
unique_name)) {
if (!this->isCached()) {
float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 0; float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 0;
float beta = ctx.HasAttr("beta") ? ctx.Attr<float>("beta") : 0; float beta = ctx.HasAttr("beta") ? ctx.Attr<float>("beta") : 0;
// eltwise_linear means we are in scale op // eltwise_linear means we are in scale op
...@@ -1036,19 +1027,13 @@ class ActivationMKLDNNHandler ...@@ -1036,19 +1027,13 @@ class ActivationMKLDNNHandler
this->AcquireForwardPrimitiveDescriptor( this->AcquireForwardPrimitiveDescriptor(
mkldnn::prop_kind::forward_training, algorithm, md, alpha, beta); mkldnn::prop_kind::forward_training, algorithm, md, alpha, beta);
} }
}
ActivationMKLDNNHandler(mkldnn::algorithm algorithm, ActivationMKLDNNHandler(mkldnn::algorithm algorithm,
const framework::ExecutionContext& ctx, const framework::ExecutionContext& ctx,
const MKLDNNDeviceContext& dev_ctx, Place cpu_place, const mkldnn::engine engine, Place cpu_place,
const framework::Tensor* in_x, const Tensor* out_grad, const framework::Tensor* in_x, const Tensor* out_grad)
const std::string& unique_name) : platform::MKLDNNHandlerNoCachingT<T, mkldnn::eltwise_forward,
: platform::MKLDNNHandlerT<T, mkldnn::eltwise_forward, mkldnn::eltwise_backward>(engine, cpu_place) {
mkldnn::eltwise_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dev_ctx, framework::vectorize(in_x->dims()),
"a", unique_name)) {
if (!this->isBwdCached()) {
float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 0; float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 0;
float beta = ctx.HasAttr("beta") ? ctx.Attr<float>("beta") : 0; float beta = ctx.HasAttr("beta") ? ctx.Attr<float>("beta") : 0;
...@@ -1077,14 +1062,11 @@ class ActivationMKLDNNHandler ...@@ -1077,14 +1062,11 @@ class ActivationMKLDNNHandler
this->AcquireBackwardPrimitiveDescriptor(algorithm, diff_dst_md, src_md, this->AcquireBackwardPrimitiveDescriptor(algorithm, diff_dst_md, src_md,
alpha, beta); alpha, beta);
} }
}
std::shared_ptr<mkldnn::memory> AcquireBackwardSrcMemory( std::shared_ptr<mkldnn::memory> AcquireBackwardSrcMemory(
const framework::Tensor* input) { const framework::Tensor* input) {
const T* input_data = input->data<T>(); const T* input_data = input->data<T>();
return this->AcquireMemoryFromPrimitive(this->bwd_pd_->src_desc(), return this->AcquireMemoryFromPrimitive(this->bwd_pd_->src_desc(), to_void_cast<T>(input_data));
to_void_cast<T>(input_data),
"@bwd-src_mem_p");
} }
}; };
...@@ -1635,11 +1617,6 @@ using ConvMKLDNNHandler = ...@@ -1635,11 +1617,6 @@ using ConvMKLDNNHandler =
mkldnn::convolution_backward_data, mkldnn::convolution_backward_data,
mkldnn::convolution_backward_weights>; mkldnn::convolution_backward_weights>;
using ConvTransposeMKLDNNHandler =
ConvMKLDNNTemplateHandler<mkldnn::deconvolution_forward,
mkldnn::deconvolution_backward_data,
mkldnn::deconvolution_backward_weights>;
template <typename T> template <typename T>
static std::shared_ptr<mkldnn::memory> SetDstMemory( static std::shared_ptr<mkldnn::memory> SetDstMemory(
const framework::ExecutionContext& ctx, framework::Tensor* output, const framework::ExecutionContext& ctx, framework::Tensor* output,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册