未验证 提交 88fc7a7d 编写于 作者: W Wojciech Uss 提交者: GitHub

fix cache key for inplaced elementwise ops (#30404)

上级 e5bb4edb
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <string>
#include <unordered_map> #include <unordered_map>
#include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h" #include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
...@@ -45,19 +46,25 @@ class EltwiseMKLDNNKernel : public framework::OpKernel<T> { ...@@ -45,19 +46,25 @@ class EltwiseMKLDNNKernel : public framework::OpKernel<T> {
float scale_x = ctx.Attr<float>("Scale_x"); float scale_x = ctx.Attr<float>("Scale_x");
float scale_y = ctx.Attr<float>("Scale_y"); float scale_y = ctx.Attr<float>("Scale_y");
float scale_o = ctx.Attr<float>("Scale_out"); float scale_o = ctx.Attr<float>("Scale_out");
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
bool is_inplaced = x->IsSharedBufferWith(*z);
std::string key = is_inplaced
? platform::CreateKey(dev_ctx, ctx.OutputName("Out"),
x->format(), y->format())
: ctx.OutputName("Out");
platform::BinaryMKLDNNHandler<T> handler( platform::BinaryMKLDNNHandler<T> handler(
BINARY_OP, axis, dev_ctx, mkldnn_engine, ctx.GetPlace(), x, y, z, BINARY_OP, axis, dev_ctx, mkldnn_engine, ctx.GetPlace(), x, y, z,
scale_x, scale_y, scale_o, ctx.OutputName("Out")); scale_x, scale_y, scale_o, key);
const auto src_x_memory = handler.AcquireSrcMemory(x); const auto src_x_memory = handler.AcquireSrcMemory(x);
const auto src_y_memory = handler.AcquireSecondSrcMemory(y); const auto src_y_memory = handler.AcquireSecondSrcMemory(y);
// For Inplace src and and dst are the same memory object // For Inplace src and and dst are the same memory object
const auto dst_memory = const auto dst_memory =
x->IsSharedBufferWith(*z) ? src_x_memory : handler.AcquireDstMemory(z); is_inplaced ? src_x_memory : handler.AcquireDstMemory(z);
const auto binary_prim = handler.AcquireForwardPrimitive(); const auto binary_prim = handler.AcquireForwardPrimitive();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册