Commit f6e981f2 authored by Jacek Czaja

- Binary is no longer manually cached

Parent 5f0422f4
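For readers outside this codebase: the commit replaces Paddle's hand-rolled primitive cache (`MKLDNNHandlerT`, keyed via `platform::CreateKey`) with the cache-free `MKLDNNHandlerNoCachingT` base, relying on oneDNN's internal primitive cache instead. A minimal sketch of the two styles — `CachedHandler` and `NoCachingHandler` are hypothetical names for illustration, not the Paddle classes:

```cpp
#include <map>
#include <memory>
#include <string>

struct Primitive {};  // stand-in for a dnnl primitive

// Old style: the handler keeps a key -> primitive registry ("manual cache").
class CachedHandler {
 public:
  explicit CachedHandler(std::string key) : key_(std::move(key)) {}
  std::shared_ptr<Primitive> AcquireForwardPrimitive() {
    auto it = registry_.find(key_);
    if (it != registry_.end()) return it->second;  // cache hit
    auto p = std::make_shared<Primitive>();        // miss: build and store
    registry_.emplace(key_, p);
    return p;
  }

 private:
  std::string key_;
  static std::map<std::string, std::shared_ptr<Primitive>> registry_;
};
std::map<std::string, std::shared_ptr<Primitive>> CachedHandler::registry_;

// New style: just construct on demand. oneDNN's own primitive cache already
// deduplicates the expensive JIT compilation, so the registry is redundant.
class NoCachingHandler {
 public:
  std::shared_ptr<Primitive> AcquireForwardPrimitive() {
    return std::make_shared<Primitive>();
  }
};
```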
@@ -47,13 +47,13 @@ class EltwiseMKLDNNKernel : public framework::OpKernel<T> {
     float scale_o = ctx.Attr<float>("Scale_out");
     int axis = ctx.Attr<int>("axis");
-    platform::BinaryMKLDNNHandler<T> handler(
-        BINARY_OP, axis, dev_ctx, mkldnn_engine, ctx.GetPlace(), x, y, z,
-        scale_x, scale_y, scale_o, ctx.OutputName("Out"));
+    platform::BinaryMKLDNNHandler<T> handler(BINARY_OP, axis, mkldnn_engine,
+                                             ctx.GetPlace(), x, y, z, scale_x,
+                                             scale_y, scale_o);
     const auto src_x_memory = handler.AcquireSrcMemory(x);
     const auto src_y_memory = handler.AcquireSecondSrcMemory(y);
-    const auto dst_memory = handler.AcquireDstMemory(z);
+    // For Inplace src and dst are the same memory object
+    auto dst_memory =
+        x->IsSharedBufferWith(*z) ? src_x_memory : handler.AcquireDstMemory(z);
     const auto binary_prim = handler.AcquireForwardPrimitive();
......
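The in-place branch added above is the one behavioral subtlety: when `x` and `z` share a buffer, the src memory object doubles as dst rather than wrapping the same allocation twice. A standalone sketch of the same pattern against raw oneDNN (2.x-era API, matching this commit; the toy shapes and `binary_add` are illustrative assumptions):

```cpp
#include <vector>
#include "dnnl.hpp"

int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::stream strm(eng);

  dnnl::memory::desc md({1, 4}, dnnl::memory::data_type::f32,
                        dnnl::memory::format_tag::nc);
  std::vector<float> x{1, 2, 3, 4}, y{10, 20, 30, 40};
  dnnl::memory src_x(md, eng, x.data());
  dnnl::memory src_y(md, eng, y.data());

  // In-place: reuse src_x as dst instead of creating a second memory object
  // over the same buffer -- the analogue of the IsSharedBufferWith branch.
  dnnl::memory dst = src_x;

  auto desc = dnnl::binary::desc(dnnl::algorithm::binary_add, md, md, md);
  auto pd = dnnl::binary::primitive_desc(desc, eng);
  dnnl::binary(pd).execute(strm, {{DNNL_ARG_SRC_0, src_x},
                                  {DNNL_ARG_SRC_1, src_y},
                                  {DNNL_ARG_DST, dst}});
  strm.wait();  // x now holds {11, 22, 33, 44}
  return 0;
}
```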
@@ -48,9 +48,8 @@ class EltwiseMulMKLDNNGradKernel : public ElemwiseGradKernel<T> {
     if (dx) {
       // dx = dout*y
       platform::BinaryMKLDNNHandler<T> handler(
-          dnnl::algorithm::binary_mul, axis, dev_ctx, mkldnn_engine,
-          ctx.GetPlace(), dout, y, dx, 1.0f, 1.0f, 1.0f,
-          ctx.InputName(framework::GradVarName("Out")));
+          dnnl::algorithm::binary_mul, axis, mkldnn_engine,
+          ctx.GetPlace(), dout, y, dx, 1.0f, 1.0f, 1.0f);
       const auto src_dout_memory = handler.AcquireSrcMemory(dout);
       const auto src_y_memory = handler.AcquireSecondSrcMemory(y);
@@ -75,9 +74,8 @@ class EltwiseMulMKLDNNGradKernel : public ElemwiseGradKernel<T> {
       // The handler is passed nullptr instead of an output tensor because we
      // want the dst buffer to be allocated by oneDNN rather than by a Tensor
       platform::BinaryMKLDNNHandler<T> handler(
-          dnnl::algorithm::binary_mul, axis, dev_ctx, mkldnn_engine,
-          ctx.GetPlace(), dout, x, nullptr, 1.0f, 1.0f, 1.0f,
-          ctx.InputName(framework::GradVarName("Out")));
+          dnnl::algorithm::binary_mul, axis, mkldnn_engine,
+          ctx.GetPlace(), dout, x, nullptr, 1.0f, 1.0f, 1.0f);
       const auto src_dout_memory = handler.AcquireSrcMemory(dout);
       const auto src_x_memory = handler.AcquireSecondSrcMemory(x);
......
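The comment in the hunk above explains the `nullptr` dst: the handler then shapes the dst descriptor from the inputs alone, and the buffer is owned by oneDNN rather than by a framework `Tensor`. A minimal illustration of such a library-owned allocation (toy dimensions are an assumption; oneDNN 2.x API):

```cpp
#include "dnnl.hpp"

int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::memory::desc dst_md({2, 3, 4, 5}, dnnl::memory::data_type::f32,
                            dnnl::memory::format_tag::nchw);
  // No user pointer supplied: oneDNN allocates and owns the buffer, so no
  // framework Tensor is needed for this intermediate result.
  dnnl::memory library_owned_dst(dst_md, eng);
  return 0;
}
```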
@@ -57,9 +57,8 @@ class SoftmaxMKLDNNHandler
                       platform::Place cpu_place, const Tensor* out,
                       const Tensor* out_grad, Tensor* in_x_grad,
                       const std::string& unique_name)
-      : platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
-                                 mkldnn::softmax_backward>(
-            dev_ctx, mkldnn_engine, cpu_place) {
+      : platform::MKLDNNHandlerNoCachingT<T, mkldnn::softmax_forward,
+                                          mkldnn::softmax_backward>(
+            mkldnn_engine, cpu_place) {
     PADDLE_ENFORCE_EQ(
         out_grad->dims(), in_x_grad->dims(),
         platform::errors::InvalidArgument("The shape of softmax_grad's input "
......
@@ -49,18 +49,17 @@ class MKLDNNHandlerNoCachingT {
   }
   std::shared_ptr<TForward> AcquireForwardPrimitive() {
-    return forward_p = std::make_shared<TForward>(*fwd_pd_);
+    return std::make_shared<TForward>(*fwd_pd_);
   }
   std::shared_ptr<TBackward> AcquireBackwardPrimitive() {
-    return backward_p = std::make_shared<TBackward>(*bwd_pd_);
+    return std::make_shared<TBackward>(*bwd_pd_);
   }
   std::shared_ptr<TBackward_params> AcquireBackwardWeightsPrimitive() {
     PADDLE_ENFORCE_NOT_NULL(bwd_w_pd_, platform::errors::Unavailable(
                                            "Error: BWD_PD should be set when "
-                                           "getting BWD prim witk key: %s .",
-                                           key_p));
+                                           "getting BWD prim."));
     return std::make_shared<TBackward_params>(*bwd_w_pd_);
   }
}
@@ -802,19 +801,13 @@ class MKLDNNHandler {
 };
 template <typename T>
-class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::binary> {
+class BinaryMKLDNNHandler
+    : public platform::MKLDNNHandlerNoCachingT<T, dnnl::binary> {
  public:
   BinaryMKLDNNHandler(const dnnl::algorithm algo, const int axis,
-                      const MKLDNNDeviceContext& dev_ctx,
                       const mkldnn::engine engine, platform::Place cpu_place,
                       const Tensor* x, const Tensor* y, Tensor* z,
-                      float scale_x, float scale_y, float scale_z,
-                      const std::string& uniq_name)
-      : platform::MKLDNNHandlerT<T, dnnl::binary>(
-            dev_ctx, engine, cpu_place,
-            platform::CreateKey(dev_ctx, framework::vectorize(x->dims()),
-                                uniq_name)) {
-    if (!this->isCached()) {
+                      float scale_x, float scale_y, float scale_z)
+      : platform::MKLDNNHandlerNoCachingT<T, dnnl::binary>(engine, cpu_place) {
     PADDLE_ENFORCE_EQ(
         x->layout(), DataLayout::kMKLDNN,
         platform::errors::InvalidArgument("Wrong layout set for X tensor."));
@@ -859,7 +852,6 @@ class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::binary> {
     this->AcquireForwardPrimitiveDescriptor(attributes, algo, src0_md,
                                             src1_md, dst_md);
   }
-  }
   std::shared_ptr<mkldnn::memory> AcquireSecondSrcMemory(
       const framework::Tensor* input) {
......