提交 2b24a801 编写于 作者: J Jacek Czaja

- Removed manual caching of activation

上级 b8322848
......@@ -105,13 +105,14 @@ template <typename T>
void eltwise_grad(const framework::ExecutionContext &ctx,
mkldnn::algorithm algorithm) {
auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
const auto *x = ctx.Input<Tensor>("X");
const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto *diff_x = ctx.Output<Tensor>(framework::GradVarName("X"));
platform::ActivationMKLDNNHandler<T> handler(
algorithm, ctx, dev_ctx, ctx.GetPlace(), x, diff_y, ctx.InputName("X"));
algorithm, ctx, mkldnn_engine, ctx.GetPlace(), x, diff_y);
auto src_memory_p = handler.AcquireBackwardSrcMemory(x);
auto diff_dst_memory_p = handler.AcquireDiffDstMemory(diff_y);
......
......@@ -29,6 +29,7 @@ class ScaleMKLDNNKernel : public framework::OpKernel<T> {
void RunKernel(const framework::ExecutionContext& ctx) const {
const auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
auto* x = ctx.Input<Tensor>("X");
auto* out = ctx.Output<Tensor>("Out");
......@@ -36,11 +37,10 @@ class ScaleMKLDNNKernel : public framework::OpKernel<T> {
bool is_inplaced = x->IsSharedBufferWith(*out);
platform::ActivationMKLDNNHandler<T> handler(
mkldnn::algorithm::eltwise_linear, ctx, dev_ctx, ctx.GetPlace(), x,
ctx.InputName("X"), is_inplaced);
mkldnn::algorithm::eltwise_linear, ctx, mkldnn_engine, ctx.GetPlace(), x);
auto src_memory_p = handler.AcquireSrcMemory(x);
auto dst_memory_p = handler.AcquireDstMemory(out);
auto dst_memory_p = is_inplaced ? src_memory_p : handler.AcquireDstMemory(out);
auto activation_p = handler.AcquireForwardPrimitive();
auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册