提交 2b24a801 编写于 作者: J Jacek Czaja

- Removed manual caching of activation

上级 b8322848
...@@ -105,13 +105,14 @@ template <typename T> ...@@ -105,13 +105,14 @@ template <typename T>
void eltwise_grad(const framework::ExecutionContext &ctx, void eltwise_grad(const framework::ExecutionContext &ctx,
mkldnn::algorithm algorithm) { mkldnn::algorithm algorithm) {
auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>(); auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
const auto *x = ctx.Input<Tensor>("X"); const auto *x = ctx.Input<Tensor>("X");
const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out")); const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto *diff_x = ctx.Output<Tensor>(framework::GradVarName("X")); auto *diff_x = ctx.Output<Tensor>(framework::GradVarName("X"));
platform::ActivationMKLDNNHandler<T> handler( platform::ActivationMKLDNNHandler<T> handler(
algorithm, ctx, dev_ctx, ctx.GetPlace(), x, diff_y, ctx.InputName("X")); algorithm, ctx, mkldnn_engine, ctx.GetPlace(), x, diff_y);
auto src_memory_p = handler.AcquireBackwardSrcMemory(x); auto src_memory_p = handler.AcquireBackwardSrcMemory(x);
auto diff_dst_memory_p = handler.AcquireDiffDstMemory(diff_y); auto diff_dst_memory_p = handler.AcquireDiffDstMemory(diff_y);
......
...@@ -29,6 +29,7 @@ class ScaleMKLDNNKernel : public framework::OpKernel<T> { ...@@ -29,6 +29,7 @@ class ScaleMKLDNNKernel : public framework::OpKernel<T> {
void RunKernel(const framework::ExecutionContext& ctx) const { void RunKernel(const framework::ExecutionContext& ctx) const {
const auto& dev_ctx = const auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>(); ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
auto* x = ctx.Input<Tensor>("X"); auto* x = ctx.Input<Tensor>("X");
auto* out = ctx.Output<Tensor>("Out"); auto* out = ctx.Output<Tensor>("Out");
...@@ -36,11 +37,10 @@ class ScaleMKLDNNKernel : public framework::OpKernel<T> { ...@@ -36,11 +37,10 @@ class ScaleMKLDNNKernel : public framework::OpKernel<T> {
bool is_inplaced = x->IsSharedBufferWith(*out); bool is_inplaced = x->IsSharedBufferWith(*out);
platform::ActivationMKLDNNHandler<T> handler( platform::ActivationMKLDNNHandler<T> handler(
mkldnn::algorithm::eltwise_linear, ctx, dev_ctx, ctx.GetPlace(), x, mkldnn::algorithm::eltwise_linear, ctx, mkldnn_engine, ctx.GetPlace(), x);
ctx.InputName("X"), is_inplaced);
auto src_memory_p = handler.AcquireSrcMemory(x); auto src_memory_p = handler.AcquireSrcMemory(x);
auto dst_memory_p = handler.AcquireDstMemory(out); auto dst_memory_p = is_inplaced ? src_memory_p : handler.AcquireDstMemory(out);
auto activation_p = handler.AcquireForwardPrimitive(); auto activation_p = handler.AcquireForwardPrimitive();
auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream(); auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册