提交 f6e981f2 编写于 作者: J Jacek Czaja

- Binary is no longer manually cached

上级 5f0422f4
...@@ -47,13 +47,13 @@ class EltwiseMKLDNNKernel : public framework::OpKernel<T> { ...@@ -47,13 +47,13 @@ class EltwiseMKLDNNKernel : public framework::OpKernel<T> {
float scale_o = ctx.Attr<float>("Scale_out"); float scale_o = ctx.Attr<float>("Scale_out");
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
platform::BinaryMKLDNNHandler<T> handler( platform::BinaryMKLDNNHandler<T> handler( BINARY_OP, axis, mkldnn_engine, ctx.GetPlace(), x, y, z, scale_x, scale_y, scale_o);
BINARY_OP, axis, dev_ctx, mkldnn_engine, ctx.GetPlace(), x, y, z,
scale_x, scale_y, scale_o, ctx.OutputName("Out"));
const auto src_x_memory = handler.AcquireSrcMemory(x); const auto src_x_memory = handler.AcquireSrcMemory(x);
const auto src_y_memory = handler.AcquireSecondSrcMemory(y); const auto src_y_memory = handler.AcquireSecondSrcMemory(y);
const auto dst_memory = handler.AcquireDstMemory(z); // For Inplace src and and dst are the same memory object
auto dst_memory =
x->IsSharedBufferWith(*z) ? src_x_memory : handler.AcquireDstMemory(z);
const auto binary_prim = handler.AcquireForwardPrimitive(); const auto binary_prim = handler.AcquireForwardPrimitive();
......
...@@ -48,9 +48,8 @@ class EltwiseMulMKLDNNGradKernel : public ElemwiseGradKernel<T> { ...@@ -48,9 +48,8 @@ class EltwiseMulMKLDNNGradKernel : public ElemwiseGradKernel<T> {
if (dx) { if (dx) {
// dx = dout*y // dx = dout*y
platform::BinaryMKLDNNHandler<T> handler( platform::BinaryMKLDNNHandler<T> handler(
dnnl::algorithm::binary_mul, axis, dev_ctx, mkldnn_engine, dnnl::algorithm::binary_mul, axis, mkldnn_engine,
ctx.GetPlace(), dout, y, dx, 1.0f, 1.0f, 1.0f, ctx.GetPlace(), dout, y, dx, 1.0f, 1.0f, 1.0f);
ctx.InputName(framework::GradVarName("Out")));
const auto src_dout_memory = handler.AcquireSrcMemory(dout); const auto src_dout_memory = handler.AcquireSrcMemory(dout);
const auto src_y_memory = handler.AcquireSecondSrcMemory(y); const auto src_y_memory = handler.AcquireSecondSrcMemory(y);
...@@ -75,9 +74,8 @@ class EltwiseMulMKLDNNGradKernel : public ElemwiseGradKernel<T> { ...@@ -75,9 +74,8 @@ class EltwiseMulMKLDNNGradKernel : public ElemwiseGradKernel<T> {
// Handler is having nullptr passed instead of output tensor as // Handler is having nullptr passed instead of output tensor as
// we want Dst buffer to be allocated by oneDNN not to use Tensor // we want Dst buffer to be allocated by oneDNN not to use Tensor
platform::BinaryMKLDNNHandler<T> handler( platform::BinaryMKLDNNHandler<T> handler(
dnnl::algorithm::binary_mul, axis, dev_ctx, mkldnn_engine, dnnl::algorithm::binary_mul, axis, mkldnn_engine,
ctx.GetPlace(), dout, x, nullptr, 1.0f, 1.0f, 1.0f, ctx.GetPlace(), dout, x, nullptr, 1.0f, 1.0f, 1.0f);
ctx.InputName(framework::GradVarName("Out")));
const auto src_dout_memory = handler.AcquireSrcMemory(dout); const auto src_dout_memory = handler.AcquireSrcMemory(dout);
const auto src_x_memory = handler.AcquireSecondSrcMemory(x); const auto src_x_memory = handler.AcquireSecondSrcMemory(x);
......
...@@ -57,9 +57,8 @@ class SoftmaxMKLDNNHandler ...@@ -57,9 +57,8 @@ class SoftmaxMKLDNNHandler
platform::Place cpu_place, const Tensor* out, platform::Place cpu_place, const Tensor* out,
const Tensor* out_grad, Tensor* in_x_grad, const Tensor* out_grad, Tensor* in_x_grad,
const std::string& unique_name) const std::string& unique_name)
: platform::MKLDNNHandlerT<T, mkldnn::softmax_forward, : platform::MKLDNNHandlerNoCachingT<T, mkldnn::softmax_forward,
mkldnn::softmax_backward>( mkldnn::softmax_backward>(mkldnn_engine, cpu_place) {
dev_ctx, mkldnn_engine, cpu_place) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
out_grad->dims(), in_x_grad->dims(), out_grad->dims(), in_x_grad->dims(),
platform::errors::InvalidArgument("The shape of softmax_grad's input " platform::errors::InvalidArgument("The shape of softmax_grad's input "
......
...@@ -49,18 +49,17 @@ class MKLDNNHandlerNoCachingT { ...@@ -49,18 +49,17 @@ class MKLDNNHandlerNoCachingT {
} }
std::shared_ptr<TForward> AcquireForwardPrimitive() { std::shared_ptr<TForward> AcquireForwardPrimitive() {
return forward_p = std::make_shared<TForward>(*fwd_pd_); return std::make_shared<TForward>(*fwd_pd_);
} }
std::shared_ptr<TBackward> AcquireBackwardPrimitive() { std::shared_ptr<TBackward> AcquireBackwardPrimitive() {
return backward_p = std::make_shared<TBackward>(*bwd_pd_); return std::make_shared<TBackward>(*bwd_pd_);
} }
std::shared_ptr<TBackward_params> AcquireBackwardWeightsPrimitive() { std::shared_ptr<TBackward_params> AcquireBackwardWeightsPrimitive() {
PADDLE_ENFORCE_NOT_NULL(bwd_w_pd_, platform::errors::Unavailable( PADDLE_ENFORCE_NOT_NULL(bwd_w_pd_, platform::errors::Unavailable(
"Error: BWD_PD should be set when " "Error: BWD_PD should be set when "
"getting BWD prim witk key: %s .", "getting BWD prim ."));
key_p));
return std::make_shared<TBackward_params>(*bwd_w_pd_); return std::make_shared<TBackward_params>(*bwd_w_pd_);
} }
...@@ -802,19 +801,13 @@ class MKLDNNHandler { ...@@ -802,19 +801,13 @@ class MKLDNNHandler {
}; };
template <typename T> template <typename T>
class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::binary> { class BinaryMKLDNNHandler : public platform::MKLDNNHandlerNoCachingT<T, dnnl::binary> {
public: public:
BinaryMKLDNNHandler(const dnnl::algorithm algo, const int axis, BinaryMKLDNNHandler(const dnnl::algorithm algo, const int axis,
const MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine engine, platform::Place cpu_place, const mkldnn::engine engine, platform::Place cpu_place,
const Tensor* x, const Tensor* y, Tensor* z, const Tensor* x, const Tensor* y, Tensor* z,
float scale_x, float scale_y, float scale_z, float scale_x, float scale_y, float scale_z)
const std::string& uniq_name) : platform::MKLDNNHandlerNoCachingT<T, dnnl::binary>(engine, cpu_place) {
: platform::MKLDNNHandlerT<T, dnnl::binary>(
dev_ctx, engine, cpu_place,
platform::CreateKey(dev_ctx, framework::vectorize(x->dims()),
uniq_name)) {
if (!this->isCached()) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
x->layout(), DataLayout::kMKLDNN, x->layout(), DataLayout::kMKLDNN,
platform::errors::InvalidArgument("Wrong layout set for X tensor.")); platform::errors::InvalidArgument("Wrong layout set for X tensor."));
...@@ -859,7 +852,6 @@ class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::binary> { ...@@ -859,7 +852,6 @@ class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::binary> {
this->AcquireForwardPrimitiveDescriptor(attributes, algo, src0_md, this->AcquireForwardPrimitiveDescriptor(attributes, algo, src0_md,
src1_md, dst_md); src1_md, dst_md);
} }
}
std::shared_ptr<mkldnn::memory> AcquireSecondSrcMemory( std::shared_ptr<mkldnn::memory> AcquireSecondSrcMemory(
const framework::Tensor* input) { const framework::Tensor* input) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册