diff --git a/paddle/fluid/operators/flatten_op.cc b/paddle/fluid/operators/flatten_op.cc index 0858a43838b964f049a9df4b431cba6dfbe693f6..14f2e9061b742f002d2a6dbb1fa26d84ee81afc4 100644 --- a/paddle/fluid/operators/flatten_op.cc +++ b/paddle/fluid/operators/flatten_op.cc @@ -77,9 +77,17 @@ class FlattenOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + auto input_data_type = + framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); + + //#ifdef PADDLE_WITH_MKLDNN + // if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { + // return framework::OpKernelType(input_data_type, ctx.GetPlace(), + // framework::DataLayout::kMKLDNN, + // framework::LibraryType::kMKLDNN); + // } + //#endif + return framework::OpKernelType(input_data_type, ctx.GetPlace()); } }; @@ -101,6 +109,14 @@ class FlattenOpMaker : public framework::OpProtoAndCheckerMaker { "tensor is (1, (d_0 X d_1 ... d_n), where the shape of the" "input tensor is (d_0, d_1, ... d_n).") .SetDefault(1); + AddAttr("use_mkldnn", + "(bool, default false) Only used in mkldnn kernel") + .SetDefault(false); + AddAttr( + "mkldnn_data_type", + "(string, default \"float32\"). Data type of mkldnn kernel") + .SetDefault("float32") + .InEnum({"float32", "bfloat16"}); AddComment(R"DOC( Flatten Operator @@ -139,9 +155,17 @@ class FlattenGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")); + + //#ifdef PADDLE_WITH_MKLDNN + // if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { + // return framework::OpKernelType(input_data_type, ctx.GetPlace(), + // framework::DataLayout::kMKLDNN, + // framework::LibraryType::kMKLDNN); + // } + //#endif + return framework::OpKernelType(input_data_type, ctx.GetPlace()); } }; @@ -198,6 +222,21 @@ class Flatten2Op : public framework::OperatorWithKernel { ctx->SetOutputDim("XShape", framework::make_ddim(xshape_dims)); ctx->ShareLoD("X", "XShape"); } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + auto input_data_type = + framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); + + //#ifdef PADDLE_WITH_MKLDNN + // if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { + // return framework::OpKernelType(input_data_type, ctx.GetPlace(), + // framework::DataLayout::kMKLDNN, + // framework::LibraryType::kMKLDNN); + // } + //#endif + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } }; class Flatten2OpMaker : public FlattenOpMaker { @@ -244,9 +283,17 @@ class Flatten2GradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.device_context()); + auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")); + + //#ifdef PADDLE_WITH_MKLDNN + // if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { + // return framework::OpKernelType(input_data_type, ctx.GetPlace(), + // framework::DataLayout::kMKLDNN, + // framework::LibraryType::kMKLDNN); + // } + //#endif + return framework::OpKernelType(input_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/mkldnn/reshape_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/reshape_mkldnn_op.cc index e6a7f3e74fcc7a22274c1e19d1d65926a0360e7f..6c3f4ec06201a115d50074a2d9c5fd9aa63743fa 100644 --- a/paddle/fluid/operators/mkldnn/reshape_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/reshape_mkldnn_op.cc @@ -12,9 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/fluid/operators/flatten_op.h" #include "paddle/fluid/operators/squeeze_op.h" #include "paddle/fluid/platform/mkldnn_reuse.h" +namespace { +enum class ReshapeKernelOpName { + reshape, + reshape2, + squeeze, + squeeze2, + flatten, + flatten2, +}; +} // anonymous namespace + namespace paddle { namespace operators { @@ -41,7 +53,7 @@ static std::vector extract_shape( return vec_new_shape; } -template +template class ReshapeMKLDNNKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -55,43 +67,13 @@ class ReshapeMKLDNNKernel : public framework::OpKernel { const auto& onednn_engine = dev_ctx.GetEngine(); auto* x = ctx.Input("X"); - auto* xshape = ctx.Output("XShape"); auto* out = ctx.Output("Out"); - framework::DDim x_dims; - // if reshape or squeeze - if (ctx.Type().find("2") == std::string::npos) { - x_dims = x->dims(); - } else { - auto xshape_dims = xshape->dims(); - x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size()); - } + framework::DDim x_dims, out_dims; + InferInOutShape(ctx, x_dims, out_dims); auto x_vec_dims = framework::vectorize(x_dims); - framework::DDim out_dims; - if (ctx.Type() == "squeeze") { - auto& axes = ctx.Attr>("axes"); - out_dims = GetOutputShape(axes, x_dims, true); - } else { - out_dims = out->dims(); - } - - if (ctx.Type().find("reshape") != std::string::npos) { - auto list_new_shape_tensor = ctx.MultiInput("ShapeTensor"); - if (list_new_shape_tensor.size() > 0) { - auto new_shape = extract_shape(list_new_shape_tensor); - out_dims = ValidateShape(new_shape, x_dims); - } else if (ctx.HasInput("Shape")) { - auto* shape_tensor = ctx.Input("Shape"); - auto* shape_data = shape_tensor->data(); - - auto shape = - std::vector(shape_data, shape_data + shape_tensor->numel()); - out_dims = ValidateShape(shape, x_dims); - } - } - mkldnn::memory::data_type x_type = framework::ToMKLDNNDataType(x->type()); platform::ReorderMKLDNNHandler reorder_handler(x_vec_dims, x->type(), x_type, onednn_engine); @@ -116,6 +98,104 @@ class ReshapeMKLDNNKernel : public framework::OpKernel { framework::vectorize(out_dims)))); } + void InferInOutShape(const framework::ExecutionContext& ctx, + framework::DDim& x_dims, + framework::DDim& out_dims) const { + switch (op_name) { + case ReshapeKernelOpName::reshape: + InferShapeReshapeOp(ctx, x_dims, out_dims); + break; + case ReshapeKernelOpName::reshape2: + InferShapeReshape2Op(ctx, x_dims, out_dims); + break; + case ReshapeKernelOpName::squeeze: + InferShapeSqueezeOp(ctx, x_dims, out_dims); + break; + case ReshapeKernelOpName::squeeze2: + InferShapeSqueeze2Op(ctx, x_dims, out_dims); + break; + case ReshapeKernelOpName::flatten: + InferShapeFlattenOp(ctx, x_dims, out_dims); + break; + case ReshapeKernelOpName::flatten2: + InferShapeFlattenOp(ctx, x_dims, out_dims); + break; + default: + PADDLE_THROW(paddle::platform::errors::OutOfRange( + "Reshape kernel doesn not support that operator name")); + } + } + + void InferShapeReshapeOp(const framework::ExecutionContext& ctx, + framework::DDim& x_dims, + framework::DDim& out_dims) const { + auto* x = ctx.Input("X"); + auto* out = ctx.Output("Out"); + x_dims = x->dims(); + out_dims = out->dims(); + ChangeReshapeOutDimsIfNeeded(ctx, x_dims, out_dims); + } + + void InferShapeReshape2Op(const framework::ExecutionContext& ctx, + framework::DDim& x_dims, + framework::DDim& out_dims) const { + auto* out = ctx.Output("Out"); + auto* xshape = ctx.Output("XShape"); + auto xshape_dims = xshape->dims(); + x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size()); + out_dims = out->dims(); + ChangeReshapeOutDimsIfNeeded(ctx, x_dims, out_dims); + } + + // in reshape1/2 ops "ShapeTensor" has highest priority and "Shape" has + // second highest priority + void ChangeReshapeOutDimsIfNeeded(const framework::ExecutionContext& ctx, + framework::DDim& x_dims, + framework::DDim& out_dims) const { + auto list_new_shape_tensor = ctx.MultiInput("ShapeTensor"); + if (list_new_shape_tensor.size() > 0) { + auto new_shape = extract_shape(list_new_shape_tensor); + out_dims = ValidateShape(new_shape, x_dims); + } else if (ctx.HasInput("Shape")) { + auto* shape_tensor = ctx.Input("Shape"); + auto* shape_data = shape_tensor->data(); + + auto shape = + std::vector(shape_data, shape_data + shape_tensor->numel()); + out_dims = ValidateShape(shape, x_dims); + } + } + + void InferShapeSqueezeOp(const framework::ExecutionContext& ctx, + framework::DDim& x_dims, + framework::DDim& out_dims) const { + auto* x = ctx.Input("X"); + x_dims = x->dims(); + const auto& axes = ctx.Attr>("axes"); + out_dims = GetOutputShape(axes, x_dims, true); + } + + void InferShapeSqueeze2Op(const framework::ExecutionContext& ctx, + framework::DDim& x_dims, + framework::DDim& out_dims) const { + auto* out = ctx.Output("Out"); + auto* xshape = ctx.Output("XShape"); + auto xshape_dims = xshape->dims(); + x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size()); + out_dims = out->dims(); + } + + void InferShapeFlattenOp(const framework::ExecutionContext& ctx, + framework::DDim& x_dims, + framework::DDim& out_dims) const { + auto x = ctx.Input("X"); + x_dims = x->dims(); + auto axes = ctx.Attr("axis"); + out_dims = framework::make_ddim( + FlattenKernel::GetOutputShape( + axes, x_dims)); + } + protected: static mkldnn::memory::format_tag getPlainFormatTag(const Tensor* tensor) { auto tensor_dims_size = tensor->dims().size(); @@ -223,8 +303,8 @@ class ReshapeMKLDNNKernel : public framework::OpKernel { } }; -template -class ReshapeGradMKLDNNKernel : public ReshapeMKLDNNKernel { +template +class ReshapeGradMKLDNNKernel : public ReshapeMKLDNNKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { RunKernel(ctx); @@ -239,14 +319,9 @@ class ReshapeGradMKLDNNKernel : public ReshapeMKLDNNKernel { auto* dout = ctx.Input(framework::GradVarName("Out")); auto* dx = ctx.Output(framework::GradVarName("X")); - framework::DDim x_dims; - // if reshape or squeeze - if (ctx.Type().find("2") == std::string::npos) { - x_dims = dx->dims(); - } else { - auto xshape_dims = ctx.Input("XShape")->dims(); - x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size()); - } + framework::DDim dx_dims; + InferOutputShapeInGrad(ctx, dx_dims); + auto dout_vec_dims = framework::vectorize(dout->dims()); mkldnn::memory::data_type dout_type = @@ -265,44 +340,128 @@ class ReshapeGradMKLDNNKernel : public ReshapeMKLDNNKernel { reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p); astream.wait(); - dx->Resize(x_dims); + dx->Resize(dx_dims); dx->set_layout(framework::DataLayout::kMKLDNN); dx->set_format(GetMKLDNNFormat(reorder_dst_memory_p->get_desc().reshape( - framework::vectorize(x_dims)))); + framework::vectorize(dx_dims)))); } -}; -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_KERNEL(squeeze, MKLDNN, paddle::platform::CPUPlace, - ops::ReshapeMKLDNNKernel, - ops::ReshapeMKLDNNKernel); - -REGISTER_OP_KERNEL(squeeze_grad, MKLDNN, paddle::platform::CPUPlace, - ops::ReshapeGradMKLDNNKernel, - ops::ReshapeGradMKLDNNKernel); -REGISTER_OP_KERNEL(squeeze2, MKLDNN, paddle::platform::CPUPlace, - ops::ReshapeMKLDNNKernel, - ops::ReshapeMKLDNNKernel); - -REGISTER_OP_KERNEL(squeeze2_grad, MKLDNN, paddle::platform::CPUPlace, - ops::ReshapeGradMKLDNNKernel, - ops::ReshapeGradMKLDNNKernel); + void InferOutputShapeInGrad(const framework::ExecutionContext& ctx, + framework::DDim& x_dims) const { + switch (op_name) { + case ReshapeKernelOpName::reshape: + InferShapeReshapeSqueezeGradOp(ctx, x_dims); + break; + case ReshapeKernelOpName::reshape2: + InferShapeReshape2Squeeze2Flatten2GradOp(ctx, x_dims); + break; + case ReshapeKernelOpName::squeeze: + InferShapeReshapeSqueezeGradOp(ctx, x_dims); + break; + case ReshapeKernelOpName::squeeze2: + InferShapeReshape2Squeeze2Flatten2GradOp(ctx, x_dims); + break; + case ReshapeKernelOpName::flatten: + InferShapeFlattenGradOp(ctx, x_dims); + break; + case ReshapeKernelOpName::flatten2: + InferShapeReshape2Squeeze2Flatten2GradOp(ctx, x_dims); + break; + default: + PADDLE_THROW(paddle::platform::errors::OutOfRange( + "Reshape grad kernel doesn not support that operator name")); + } + } -REGISTER_OP_KERNEL(reshape, MKLDNN, paddle::platform::CPUPlace, - ops::ReshapeMKLDNNKernel, - ops::ReshapeMKLDNNKernel); + void InferShapeReshapeSqueezeGradOp(const framework::ExecutionContext& ctx, + framework::DDim& dx_dims) const { + auto* dx = ctx.Output(framework::GradVarName("X")); + dx_dims = dx->dims(); + } -REGISTER_OP_KERNEL(reshape_grad, MKLDNN, paddle::platform::CPUPlace, - ops::ReshapeGradMKLDNNKernel, - ops::ReshapeGradMKLDNNKernel); + void InferShapeReshape2Squeeze2Flatten2GradOp( + const framework::ExecutionContext& ctx, framework::DDim& dx_dims) const { + auto xshape_dims = ctx.Input("XShape")->dims(); + dx_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size()); + } -REGISTER_OP_KERNEL(reshape2, MKLDNN, paddle::platform::CPUPlace, - ops::ReshapeMKLDNNKernel, - ops::ReshapeMKLDNNKernel); + void InferShapeFlattenGradOp(const framework::ExecutionContext& ctx, + framework::DDim& dx_dims) const { + dx_dims = ctx.Input("X")->dims(); + } +}; +} // namespace operators +} // namespace paddle -REGISTER_OP_KERNEL(reshape2_grad, MKLDNN, paddle::platform::CPUPlace, - ops::ReshapeGradMKLDNNKernel, - ops::ReshapeGradMKLDNNKernel); +namespace ops = paddle::operators; +REGISTER_OP_KERNEL( + squeeze, MKLDNN, paddle::platform::CPUPlace, + ops::ReshapeMKLDNNKernel, + ops::ReshapeMKLDNNKernel); + +REGISTER_OP_KERNEL( + squeeze_grad, MKLDNN, paddle::platform::CPUPlace, + ops::ReshapeGradMKLDNNKernel, + ops::ReshapeGradMKLDNNKernel); + +REGISTER_OP_KERNEL( + squeeze2, MKLDNN, paddle::platform::CPUPlace, + ops::ReshapeMKLDNNKernel, + ops::ReshapeMKLDNNKernel); + +REGISTER_OP_KERNEL( + squeeze2_grad, MKLDNN, paddle::platform::CPUPlace, + ops::ReshapeGradMKLDNNKernel, + ops::ReshapeGradMKLDNNKernel); + +REGISTER_OP_KERNEL( + reshape, MKLDNN, paddle::platform::CPUPlace, + ops::ReshapeMKLDNNKernel, + ops::ReshapeMKLDNNKernel); + +REGISTER_OP_KERNEL( + reshape_grad, MKLDNN, paddle::platform::CPUPlace, + ops::ReshapeGradMKLDNNKernel, + ops::ReshapeGradMKLDNNKernel); + +REGISTER_OP_KERNEL( + reshape2, MKLDNN, paddle::platform::CPUPlace, + ops::ReshapeMKLDNNKernel, + ops::ReshapeMKLDNNKernel); + +REGISTER_OP_KERNEL( + reshape2_grad, MKLDNN, paddle::platform::CPUPlace, + ops::ReshapeGradMKLDNNKernel, + ops::ReshapeGradMKLDNNKernel); + +REGISTER_OP_KERNEL( + flatten, MKLDNN, paddle::platform::CPUPlace, + ops::ReshapeMKLDNNKernel, + ops::ReshapeMKLDNNKernel); + +REGISTER_OP_KERNEL( + flatten_grad, MKLDNN, paddle::platform::CPUPlace, + ops::ReshapeGradMKLDNNKernel, + ops::ReshapeGradMKLDNNKernel); + +REGISTER_OP_KERNEL( + flatten2, MKLDNN, paddle::platform::CPUPlace, + ops::ReshapeMKLDNNKernel, + ops::ReshapeMKLDNNKernel); + +REGISTER_OP_KERNEL( + flatten2_grad, MKLDNN, paddle::platform::CPUPlace, + ops::ReshapeGradMKLDNNKernel, + ops::ReshapeGradMKLDNNKernel); diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index c74f0f0e499b44464ee650f7e10c4096adcb692c..6f244b1a4cb8fe43b0acff27c67ee08ca440445a 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -248,13 +248,13 @@ class ReshapeOp : public framework::OperatorWithKernel { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); -#ifdef PADDLE_WITH_MKLDNN -// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { -// return framework::OpKernelType(input_data_type, ctx.GetPlace(), -// framework::DataLayout::kMKLDNN, -// framework::LibraryType::kMKLDNN); -// } -#endif + //#ifdef PADDLE_WITH_MKLDNN + // if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { + // return framework::OpKernelType(input_data_type, ctx.GetPlace(), + // framework::DataLayout::kMKLDNN, + // framework::LibraryType::kMKLDNN); + // } + //#endif return framework::OpKernelType(input_data_type, ctx.GetPlace()); } @@ -366,13 +366,13 @@ class ReshapeGradOp : public framework::OperatorWithKernel { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); -#ifdef PADDLE_WITH_MKLDNN -// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { -// return framework::OpKernelType(input_data_type, ctx.GetPlace(), -// framework::DataLayout::kMKLDNN, -// framework::LibraryType::kMKLDNN); -// } -#endif + //#ifdef PADDLE_WITH_MKLDNN + // if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { + // return framework::OpKernelType(input_data_type, ctx.GetPlace(), + // framework::DataLayout::kMKLDNN, + // framework::LibraryType::kMKLDNN); + // } + //#endif return framework::OpKernelType(input_data_type, ctx.GetPlace()); } }; @@ -557,13 +557,13 @@ class Reshape2GradOp : public framework::OperatorWithKernel { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); -#ifdef PADDLE_WITH_MKLDNN -// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { -// return framework::OpKernelType(input_data_type, ctx.GetPlace(), -// framework::DataLayout::kMKLDNN, -// framework::LibraryType::kMKLDNN); -// } -#endif + //#ifdef PADDLE_WITH_MKLDNN + // if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { + // return framework::OpKernelType(input_data_type, ctx.GetPlace(), + // framework::DataLayout::kMKLDNN, + // framework::LibraryType::kMKLDNN); + // } + //#endif return framework::OpKernelType(input_data_type, ctx.GetPlace()); } diff --git a/paddle/fluid/operators/squeeze_op.cc b/paddle/fluid/operators/squeeze_op.cc index 8894ca650de034f8c7f9c0d48d76faa33d164d05..de30eab25f3cf2c37a2f43cc6e11490b01075229 100644 --- a/paddle/fluid/operators/squeeze_op.cc +++ b/paddle/fluid/operators/squeeze_op.cc @@ -113,13 +113,13 @@ class SqueezeOp : public framework::OperatorWithKernel { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); -#ifdef PADDLE_WITH_MKLDNN -// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { -// return framework::OpKernelType(input_data_type, ctx.GetPlace(), -// framework::DataLayout::kMKLDNN, -// framework::LibraryType::kMKLDNN); -// } -#endif + //#ifdef PADDLE_WITH_MKLDNN + // if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { + // return framework::OpKernelType(input_data_type, ctx.GetPlace(), + // framework::DataLayout::kMKLDNN, + // framework::LibraryType::kMKLDNN); + // } + //#endif return framework::OpKernelType(input_data_type, ctx.GetPlace()); } }; @@ -140,13 +140,13 @@ class SqueezeGradOp : public framework::OperatorWithKernel { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); -#ifdef PADDLE_WITH_MKLDNN -// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { -// return framework::OpKernelType(input_data_type, ctx.GetPlace(), -// framework::DataLayout::kMKLDNN, -// framework::LibraryType::kMKLDNN); -// } -#endif + //#ifdef PADDLE_WITH_MKLDNN + // if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { + // return framework::OpKernelType(input_data_type, ctx.GetPlace(), + // framework::DataLayout::kMKLDNN, + // framework::LibraryType::kMKLDNN); + // } + //#endif return framework::OpKernelType(input_data_type, ctx.GetPlace()); } }; @@ -241,13 +241,13 @@ class Squeeze2Op : public framework::OperatorWithKernel { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); -#ifdef PADDLE_WITH_MKLDNN -// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { -// return framework::OpKernelType(input_data_type, ctx.GetPlace(), -// framework::DataLayout::kMKLDNN, -// framework::LibraryType::kMKLDNN); -// } -#endif + //#ifdef PADDLE_WITH_MKLDNN + // if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { + // return framework::OpKernelType(input_data_type, ctx.GetPlace(), + // framework::DataLayout::kMKLDNN, + // framework::LibraryType::kMKLDNN); + // } + //#endif return framework::OpKernelType(input_data_type, ctx.GetPlace()); } }; @@ -287,13 +287,13 @@ class Squeeze2GradOp : public framework::OperatorWithKernel { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); -#ifdef PADDLE_WITH_MKLDNN -// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { -// return framework::OpKernelType(input_data_type, ctx.GetPlace(), -// framework::DataLayout::kMKLDNN, -// framework::LibraryType::kMKLDNN); -// } -#endif + //#ifdef PADDLE_WITH_MKLDNN + // if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { + // return framework::OpKernelType(input_data_type, ctx.GetPlace(), + // framework::DataLayout::kMKLDNN, + // framework::LibraryType::kMKLDNN); + // } + //#endif return framework::OpKernelType(input_data_type, ctx.GetPlace()); } }; diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_flatten_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_flatten_mkldnn_op.py new file mode 100644 index 0000000000000000000000000000000000000000..c01f244004effb260c1c89ede87477bdf5735aca --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_flatten_mkldnn_op.py @@ -0,0 +1,151 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import paddle +import paddle.fluid.core as core + +from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool, convert_float_to_uint16 + + +@OpTestTool.skip_if_not_cpu_bf16() +class TestFlattenOneDNNOp(OpTest): + def setUp(self): + self.set_op_type() + self.init_test_case() + self.set_inputs() + self.attrs = {"axis": self.axis, 'use_mkldnn': True} + self.ori_shape = self.inputs['X'].shape + self.outputs = {"Out": self.inputs["X"].copy().reshape(self.new_shape)} + + def set_inputs(self): + self.inputs = {"X": np.random.random(self.in_shape).astype("float32")} + + def set_op_type(self): + self.op_type = "flatten" + + def test_check_output(self): + self.check_output_with_place(core.CPUPlace()) + + def test_check_grad(self): + self.check_grad_with_place(core.CPUPlace(), ["X"], "Out") + + def init_test_case(self): + self.in_shape = (3, 2, 2, 10) + self.axis = 1 + self.new_shape = (3, 40) + + +class TestFlattenOneDNNOp1(TestFlattenOneDNNOp): + def init_test_case(self): + self.in_shape = (3, 2, 2, 10) + self.axis = 0 + self.new_shape = (1, 120) + + +class TestFlattenOneDNNOpSixDims(TestFlattenOneDNNOp): + def init_test_case(self): + self.in_shape = (3, 2, 3, 2, 4, 4) + self.axis = 4 + self.new_shape = (36, 16) + + +class TestFlatten2OneDNNOp(TestFlattenOneDNNOp): + def set_op_type(self): + self.op_type = "flatten2" + + +class TestFlatten2OneDNNOp1(TestFlattenOneDNNOp1): + def set_op_type(self): + self.op_type = "flatten2" + + +class TestFlatten2OneDNNOpSixDims(TestFlattenOneDNNOpSixDims): + def set_op_type(self): + self.op_type = "flatten2" + + +# BF16 TESTS +def create_flatten_bf16_test_classes(parent): + class TestFlatten2BF16OneDNNOp(parent): + def set_inputs(self): + self.dtype = np.uint16 + self.inputs = { + "X": np.random.random(self.in_shape).astype("uint16") + } + + def calculate_grads(self): + self.dout = self.outputs['Out'] + self.dx = np.reshape(self.dout, self.ori_shape) + + def test_check_output(self): + self.check_output_with_place( + core.CPUPlace(), no_check_set=["XShape"]) + + def test_check_grad(self): + self.calculate_grads() + self.check_grad_with_place( + core.CPUPlace(), ["X"], + "Out", + user_defined_grads=[self.dx], + user_defined_grad_outputs=[self.dout]) + + cls_name = "{0}_{1}".format(parent.__name__, "Flatten2_BF16") + TestFlatten2BF16OneDNNOp.__name__ = cls_name + globals()[cls_name] = TestFlatten2BF16OneDNNOp + + class TestFlattenBF16OneDNNOp(parent): + def set_op_type(self): + self.dtype = np.uint16 + self.op_type = "flatten" + + def set_inputs(self): + self.dtype = np.uint16 + self.inputs = { + "X": np.random.random(self.in_shape).astype("uint16") + } + + def set_outputs(self): + self.outputs = {"Out": self.x.reshape(self.new_shape)} + + def calculate_grads(self): + self.dout = self.outputs['Out'] + self.dx = np.reshape(self.dout, self.ori_shape) + + def test_check_output(self): + self.check_output_with_place(core.CPUPlace()) + + def test_check_grad(self): + self.calculate_grads() + self.check_grad_with_place( + core.CPUPlace(), ["X"], + "Out", + user_defined_grads=[self.dx], + user_defined_grad_outputs=[convert_float_to_uint16(self.dout)]) + + cls_name = "{0}_{1}".format(parent.__name__, "Flatten_BF16") + TestFlattenBF16OneDNNOp.__name__ = cls_name + globals()[cls_name] = TestFlattenBF16OneDNNOp + + +create_flatten_bf16_test_classes(TestFlatten2OneDNNOp) +create_flatten_bf16_test_classes(TestFlatten2OneDNNOp1) +create_flatten_bf16_test_classes(TestFlatten2OneDNNOpSixDims) + +if __name__ == "__main__": + paddle.enable_static() + unittest.main()