diff --git a/paddle/fluid/operators/transpose_mkldnn_op.cc b/paddle/fluid/operators/transpose_mkldnn_op.cc index 2f133c9e251388e9e78a6a49ca66a45a56eef76e..e6df7028f540d0928e2bb0763bd4cfef12059665 100644 --- a/paddle/fluid/operators/transpose_mkldnn_op.cc +++ b/paddle/fluid/operators/transpose_mkldnn_op.cc @@ -29,10 +29,6 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel { void Compute(const paddle::framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()), "It must use CPUPlace."); - const bool is_test = ctx.Attr("is_test"); - PADDLE_ENFORCE( - is_test == true, - "TransposeMKLDNN works only for inference!. Set is_test = True"); auto& dev_ctx = ctx.template device_context(); const auto& mkldnn_engine = dev_ctx.GetEngine(); @@ -68,6 +64,57 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel { } }; +template +class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel { + public: + void Compute(const paddle::framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()), + "It must use CPUPlace."); + auto* out_grad = + ctx.Input(framework::GradVarName("Out")); + auto* x_grad = ctx.Output(framework::GradVarName("X")); + if (!x_grad) return; + + auto& dev_ctx = + ctx.template device_context(); + const auto& mkldnn_engine = dev_ctx.GetEngine(); + std::vector axis = ctx.Attr>("axis"); + std::vector reversed_axis(axis); + int ndims = axis.size(); + if (ndims == 1) { + x_grad->ShareDataWith(*out_grad); + return; + } + + for (size_t i = 0; i < axis.size(); i++) { + reversed_axis[axis[i]] = i; + } + + const T* out_grad_data = out_grad->data(); + x_grad->mutable_data(ctx.GetPlace()); + + std::vector nchw_tz = + paddle::framework::vectorize2int(out_grad->dims()); + + const std::string key = platform::TransposeMKLDNNHandler::GetHash( + nchw_tz, axis, ctx.op().Output(framework::GradVarName("X"))); + + platform::TransposeMKLDNNHandler handler(nchw_tz, reversed_axis, dev_ctx, + mkldnn_engine, key); + + auto transpose_src_memory_p = handler.AcquireSrcMemory( + out_grad->format(), platform::to_void_cast(out_grad_data)); + auto transpose_dst_memory_p = + handler.AcquireDstMemory(x_grad, ctx.GetPlace()); + auto transpose_p = handler.AcquireTranspose(transpose_dst_memory_p, + transpose_src_memory_p); + + std::vector pipeline; + pipeline.push_back(*transpose_p); + mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); + } +}; + } // namespace operators } // namespace paddle @@ -77,3 +124,8 @@ REGISTER_OP_KERNEL(transpose2, MKLDNN, ::paddle::platform::CPUPlace, ops::TransposeMKLDNNOpKernel); REGISTER_OP_KERNEL(transpose, MKLDNN, ::paddle::platform::CPUPlace, ops::TransposeMKLDNNOpKernel); + +REGISTER_OP_KERNEL(transpose_grad, MKLDNN, ::paddle::platform::CPUPlace, + ops::TransposeMKLDNNGradOpKernel); +REGISTER_OP_KERNEL(transpose2_grad, MKLDNN, ::paddle::platform::CPUPlace, + ops::TransposeMKLDNNGradOpKernel); diff --git a/paddle/fluid/operators/transpose_op.cc b/paddle/fluid/operators/transpose_op.cc index b3b379d16ff099ba244fc92ed149a0089c2750e4..db14d350c7d92629873dfc5bc9181f651582e47c 100644 --- a/paddle/fluid/operators/transpose_op.cc +++ b/paddle/fluid/operators/transpose_op.cc @@ -79,10 +79,6 @@ class TransposeOp : public framework::OperatorWithKernel { class TransposeOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { - AddAttr("is_test", - "(bool, default false) Set to true for inference only, false " - "for training. Some layers may run faster when this is true.") - .SetDefault(false); AddInput( "X", "(Tensor) The input tensor, tensors with rank up to 6 are supported."); @@ -147,6 +143,24 @@ class TransposeOpGrad : public framework::OperatorWithKernel { ctx->SetOutputDim(framework::GradVarName("X"), x_dims); } } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + framework::LibraryType library_{framework::LibraryType::kPlain}; + std::string data_format = ctx.Attr("data_format"); + framework::DataLayout layout_ = framework::StringToDataLayout(data_format); +#ifdef PADDLE_WITH_MKLDNN + if (library_ == framework::LibraryType::kPlain && + platform::CanMKLDNNBeUsed(ctx)) { + library_ = framework::LibraryType::kMKLDNN; + layout_ = framework::DataLayout::kMKLDNN; + } +#endif + return framework::OpKernelType( + ctx.Input(framework::GradVarName("Out"))->type(), + ctx.GetPlace(), layout_, library_); + } }; // FIXME(zcd): transpose2 adds an intermediate output(XShape) based on @@ -237,9 +251,19 @@ class Transpose2OpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { + framework::LibraryType library_{framework::LibraryType::kPlain}; + std::string data_format = ctx.Attr("data_format"); + framework::DataLayout layout_ = framework::StringToDataLayout(data_format); +#ifdef PADDLE_WITH_MKLDNN + if (library_ == framework::LibraryType::kPlain && + platform::CanMKLDNNBeUsed(ctx)) { + library_ = framework::LibraryType::kMKLDNN; + layout_ = framework::DataLayout::kMKLDNN; + } +#endif return framework::OpKernelType( ctx.Input(framework::GradVarName("Out"))->type(), - ctx.device_context()); + ctx.GetPlace(), layout_, library_); } }; diff --git a/python/paddle/fluid/tests/unittests/test_transpose_mkldnn_op.py b/python/paddle/fluid/tests/unittests/test_transpose_mkldnn_op.py index 61ac8790112ceadfdef7b18aad70af77644581cd..0c201b9e4f48df94924a248d820ae2cf73367560 100644 --- a/python/paddle/fluid/tests/unittests/test_transpose_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/test_transpose_mkldnn_op.py @@ -23,16 +23,6 @@ class TestTransposeMKLDNN(TestTransposeOp): def init_op_type(self): self.op_type = "transpose2" self.use_mkldnn = True - self.is_test = True - return - - def test_check_grad(self): - return - - def test_check_grad_no_input(self): - return - - def test_check_grad_no_filter(self): return diff --git a/python/paddle/fluid/tests/unittests/test_transpose_op.py b/python/paddle/fluid/tests/unittests/test_transpose_op.py index 93be9d28da7a73f4fa972acf0dbd95167e7dfca3..a38540a7240636415ef4703609c5a3e8e83ed1da 100644 --- a/python/paddle/fluid/tests/unittests/test_transpose_op.py +++ b/python/paddle/fluid/tests/unittests/test_transpose_op.py @@ -27,7 +27,6 @@ class TestTransposeOp(OpTest): self.attrs = { 'axis': list(self.axis), 'use_mkldnn': self.use_mkldnn, - 'is_test': self.is_test, } self.outputs = { 'XShape': np.random.random(self.shape).astype("float32"), @@ -37,7 +36,6 @@ class TestTransposeOp(OpTest): def init_op_type(self): self.op_type = "transpose2" self.use_mkldnn = False - self.is_test = False def test_check_output(self): self.check_output(no_check_set=['XShape'])