Unverified · Commit e427a0f1 · authored by J jakpiase, committed by GitHub

Added flatten and flatten2 BF16/FP32 FWD/BWD kernels (#35892)

* refactored reshape multiop kernel and added flatten/flatten2 kernels

* added formatting for flatten tests

* CI fix

* disabled reshape_kernel ops after successful CI run

* minor fix
Parent ec2f68e8
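The heart of this change is that a single oneDNN kernel template now serves reshape, reshape2, squeeze, squeeze2, flatten, and flatten2: the op name becomes a compile-time template parameter, and per-op shape inference is selected by a switch over that parameter instead of string-matching `ctx.Type()` at run time. A minimal standalone sketch of the dispatch pattern (illustrative names, not the actual Paddle API):

```cpp
#include <iostream>

// Sketch only: one kernel template parameterized by an op-name enum, so the
// shape-inference path is chosen per template instantiation.
enum class ReshapeKernelOpName { reshape, reshape2, squeeze, squeeze2, flatten, flatten2 };

template <ReshapeKernelOpName op_name>
struct ReshapeLikeKernel {
  void InferInOutShape() const {
    switch (op_name) {
      case ReshapeKernelOpName::flatten:
      case ReshapeKernelOpName::flatten2:
        std::cout << "derive output dims from the 'axis' attribute\n";
        break;
      case ReshapeKernelOpName::squeeze:
      case ReshapeKernelOpName::squeeze2:
        std::cout << "drop the axes listed in 'axes'\n";
        break;
      default:  // reshape / reshape2
        std::cout << "read the 'Shape'/'ShapeTensor' inputs\n";
        break;
    }
  }
};

int main() {
  // Each instantiation is registered separately, mirroring the
  // REGISTER_OP_KERNEL(..., ReshapeMKLDNNKernel<float, op>) calls below.
  ReshapeLikeKernel<ReshapeKernelOpName::flatten2>{}.InferInOutShape();
  ReshapeLikeKernel<ReshapeKernelOpName::reshape>{}.InferInOutShape();
}
```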
@@ -77,9 +77,17 @@ class FlattenOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        ctx.device_context());
+    auto input_data_type =
+        framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
+
+    //#ifdef PADDLE_WITH_MKLDNN
+    //    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
+    //      return framework::OpKernelType(input_data_type, ctx.GetPlace(),
+    //                                     framework::DataLayout::kMKLDNN,
+    //                                     framework::LibraryType::kMKLDNN);
+    //    }
+    //#endif
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
 };
@@ -101,6 +109,14 @@ class FlattenOpMaker : public framework::OpProtoAndCheckerMaker {
                  "tensor is (1, (d_0 X d_1 ... d_n), where the shape of the"
                  "input tensor is (d_0, d_1, ... d_n).")
         .SetDefault(1);
+    AddAttr<bool>("use_mkldnn",
+                  "(bool, default false) Only used in mkldnn kernel")
+        .SetDefault(false);
+    AddAttr<std::string>(
+        "mkldnn_data_type",
+        "(string, default \"float32\"). Data type of mkldnn kernel")
+        .SetDefault("float32")
+        .InEnum({"float32", "bfloat16"});
     AddComment(R"DOC(
 Flatten Operator
@@ -139,9 +155,17 @@ class FlattenGradOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
-                                       ctx, framework::GradVarName("Out")),
-                                   ctx.device_context());
+    auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(
+        ctx, framework::GradVarName("Out"));
+
+    //#ifdef PADDLE_WITH_MKLDNN
+    //    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
+    //      return framework::OpKernelType(input_data_type, ctx.GetPlace(),
+    //                                     framework::DataLayout::kMKLDNN,
+    //                                     framework::LibraryType::kMKLDNN);
+    //    }
+    //#endif
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
 };
@@ -198,6 +222,21 @@ class Flatten2Op : public framework::OperatorWithKernel {
     ctx->SetOutputDim("XShape", framework::make_ddim(xshape_dims));
     ctx->ShareLoD("X", "XShape");
   }
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    auto input_data_type =
+        framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
+
+    //#ifdef PADDLE_WITH_MKLDNN
+    //    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
+    //      return framework::OpKernelType(input_data_type, ctx.GetPlace(),
+    //                                     framework::DataLayout::kMKLDNN,
+    //                                     framework::LibraryType::kMKLDNN);
+    //    }
+    //#endif
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+  }
 };
 
 class Flatten2OpMaker : public FlattenOpMaker {
@@ -244,9 +283,17 @@ class Flatten2GradOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
-                                       ctx, framework::GradVarName("Out")),
-                                   ctx.device_context());
+    auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(
+        ctx, framework::GradVarName("Out"));
+
+    //#ifdef PADDLE_WITH_MKLDNN
+    //    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
+    //      return framework::OpKernelType(input_data_type, ctx.GetPlace(),
+    //                                     framework::DataLayout::kMKLDNN,
+    //                                     framework::LibraryType::kMKLDNN);
+    //    }
+    //#endif
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
 };
...
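The flatten `axis` attribute documented above works as follows: input dims `[0, axis)` are multiplied into the first output dimension and dims `[axis, n)` into the second, so `axis = 0` yields a leading 1. A self-contained sketch of that rule (a hypothetical helper, not the operator's actual `GetOutputShape`), checked against the shapes used by the unit tests at the end of this commit:

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical helper: collapse dims [0, axis) into the outer output dim and
// dims [axis, n) into the inner one; an empty product is 1, so axis == 0
// gives the documented (1, d_0 X d_1 ... d_n) shape.
std::vector<int64_t> FlattenOutShape(const std::vector<int64_t>& in, int axis) {
  auto prod = [](auto first, auto last) {
    return std::accumulate(first, last, int64_t{1}, std::multiplies<int64_t>());
  };
  return {prod(in.begin(), in.begin() + axis),
          prod(in.begin() + axis, in.end())};
}

int main() {
  // Shapes taken from the flatten unit tests in this commit.
  assert((FlattenOutShape({3, 2, 2, 10}, 1) == std::vector<int64_t>{3, 40}));
  assert((FlattenOutShape({3, 2, 2, 10}, 0) == std::vector<int64_t>{1, 120}));
  assert((FlattenOutShape({3, 2, 3, 2, 4, 4}, 4) == std::vector<int64_t>{36, 16}));
}
```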
@@ -12,9 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#include "paddle/fluid/operators/flatten_op.h"
 #include "paddle/fluid/operators/squeeze_op.h"
 #include "paddle/fluid/platform/mkldnn_reuse.h"
 
+namespace {
+enum class ReshapeKernelOpName {
+  reshape,
+  reshape2,
+  squeeze,
+  squeeze2,
+  flatten,
+  flatten2,
+};
+}  // anonymous namespace
+
 namespace paddle {
 namespace operators {
@@ -41,7 +53,7 @@ static std::vector<int> extract_shape(
   return vec_new_shape;
 }
 
-template <typename T>
+template <typename T, ReshapeKernelOpName op_name>
 class ReshapeMKLDNNKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -55,43 +67,13 @@ class ReshapeMKLDNNKernel : public framework::OpKernel<T> {
     const auto& onednn_engine = dev_ctx.GetEngine();
 
     auto* x = ctx.Input<LoDTensor>("X");
-    auto* xshape = ctx.Output<LoDTensor>("XShape");
     auto* out = ctx.Output<LoDTensor>("Out");
 
-    framework::DDim x_dims;
-    // if reshape or squeeze
-    if (ctx.Type().find("2") == std::string::npos) {
-      x_dims = x->dims();
-    } else {
-      auto xshape_dims = xshape->dims();
-      x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
-    }
+    framework::DDim x_dims, out_dims;
+    InferInOutShape(ctx, x_dims, out_dims);
 
     auto x_vec_dims = framework::vectorize(x_dims);
 
-    framework::DDim out_dims;
-    if (ctx.Type() == "squeeze") {
-      auto& axes = ctx.Attr<std::vector<int>>("axes");
-      out_dims = GetOutputShape(axes, x_dims, true);
-    } else {
-      out_dims = out->dims();
-    }
-
-    if (ctx.Type().find("reshape") != std::string::npos) {
-      auto list_new_shape_tensor = ctx.MultiInput<Tensor>("ShapeTensor");
-      if (list_new_shape_tensor.size() > 0) {
-        auto new_shape = extract_shape(list_new_shape_tensor);
-        out_dims = ValidateShape(new_shape, x_dims);
-      } else if (ctx.HasInput("Shape")) {
-        auto* shape_tensor = ctx.Input<framework::LoDTensor>("Shape");
-        auto* shape_data = shape_tensor->data<int>();
-        auto shape =
-            std::vector<int>(shape_data, shape_data + shape_tensor->numel());
-        out_dims = ValidateShape(shape, x_dims);
-      }
-    }
-
     mkldnn::memory::data_type x_type = framework::ToMKLDNNDataType(x->type());
     platform::ReorderMKLDNNHandler reorder_handler(x_vec_dims, x->type(),
                                                    x_type, onednn_engine);
@@ -116,6 +98,104 @@ class ReshapeMKLDNNKernel : public framework::OpKernel<T> {
         framework::vectorize(out_dims))));
   }
 
+  void InferInOutShape(const framework::ExecutionContext& ctx,
+                       framework::DDim& x_dims,
+                       framework::DDim& out_dims) const {
+    switch (op_name) {
+      case ReshapeKernelOpName::reshape:
+        InferShapeReshapeOp(ctx, x_dims, out_dims);
+        break;
+      case ReshapeKernelOpName::reshape2:
+        InferShapeReshape2Op(ctx, x_dims, out_dims);
+        break;
+      case ReshapeKernelOpName::squeeze:
+        InferShapeSqueezeOp(ctx, x_dims, out_dims);
+        break;
+      case ReshapeKernelOpName::squeeze2:
+        InferShapeSqueeze2Op(ctx, x_dims, out_dims);
+        break;
+      case ReshapeKernelOpName::flatten:
+        InferShapeFlattenOp(ctx, x_dims, out_dims);
+        break;
+      case ReshapeKernelOpName::flatten2:
+        InferShapeFlattenOp(ctx, x_dims, out_dims);
+        break;
+      default:
+        PADDLE_THROW(paddle::platform::errors::OutOfRange(
+            "Reshape kernel does not support that operator name"));
+    }
+  }
+
+  void InferShapeReshapeOp(const framework::ExecutionContext& ctx,
+                           framework::DDim& x_dims,
+                           framework::DDim& out_dims) const {
+    auto* x = ctx.Input<LoDTensor>("X");
+    auto* out = ctx.Output<LoDTensor>("Out");
+    x_dims = x->dims();
+    out_dims = out->dims();
+    ChangeReshapeOutDimsIfNeeded(ctx, x_dims, out_dims);
+  }
+
+  void InferShapeReshape2Op(const framework::ExecutionContext& ctx,
+                            framework::DDim& x_dims,
+                            framework::DDim& out_dims) const {
+    auto* out = ctx.Output<LoDTensor>("Out");
+    auto* xshape = ctx.Output<LoDTensor>("XShape");
+    auto xshape_dims = xshape->dims();
+    x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
+    out_dims = out->dims();
+    ChangeReshapeOutDimsIfNeeded(ctx, x_dims, out_dims);
+  }
+
+  // in reshape1/2 ops "ShapeTensor" has highest priority and "Shape" has
+  // second highest priority
+  void ChangeReshapeOutDimsIfNeeded(const framework::ExecutionContext& ctx,
+                                    framework::DDim& x_dims,
+                                    framework::DDim& out_dims) const {
+    auto list_new_shape_tensor = ctx.MultiInput<Tensor>("ShapeTensor");
+    if (list_new_shape_tensor.size() > 0) {
+      auto new_shape = extract_shape(list_new_shape_tensor);
+      out_dims = ValidateShape(new_shape, x_dims);
+    } else if (ctx.HasInput("Shape")) {
+      auto* shape_tensor = ctx.Input<framework::LoDTensor>("Shape");
+      auto* shape_data = shape_tensor->data<int>();
+
+      auto shape =
+          std::vector<int>(shape_data, shape_data + shape_tensor->numel());
+      out_dims = ValidateShape(shape, x_dims);
+    }
+  }
+
+  void InferShapeSqueezeOp(const framework::ExecutionContext& ctx,
+                           framework::DDim& x_dims,
+                           framework::DDim& out_dims) const {
+    auto* x = ctx.Input<LoDTensor>("X");
+    x_dims = x->dims();
+    const auto& axes = ctx.Attr<std::vector<int>>("axes");
+    out_dims = GetOutputShape(axes, x_dims, true);
+  }
+
+  void InferShapeSqueeze2Op(const framework::ExecutionContext& ctx,
+                            framework::DDim& x_dims,
+                            framework::DDim& out_dims) const {
+    auto* out = ctx.Output<LoDTensor>("Out");
+    auto* xshape = ctx.Output<LoDTensor>("XShape");
+    auto xshape_dims = xshape->dims();
+    x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
+    out_dims = out->dims();
+  }
+
+  void InferShapeFlattenOp(const framework::ExecutionContext& ctx,
+                           framework::DDim& x_dims,
+                           framework::DDim& out_dims) const {
+    auto x = ctx.Input<LoDTensor>("X");
+    x_dims = x->dims();
+    auto axes = ctx.Attr<int>("axis");
+    out_dims = framework::make_ddim(
+        FlattenKernel<platform::CPUDeviceContext, float>::GetOutputShape(
+            axes, x_dims));
+  }
+
  protected:
   static mkldnn::memory::format_tag getPlainFormatTag(const Tensor* tensor) {
     auto tensor_dims_size = tensor->dims().size();
@@ -223,8 +303,8 @@ class ReshapeMKLDNNKernel : public framework::OpKernel<T> {
   }
 };
 
-template <typename T>
-class ReshapeGradMKLDNNKernel : public ReshapeMKLDNNKernel<T> {
+template <typename T, ReshapeKernelOpName op_name>
+class ReshapeGradMKLDNNKernel : public ReshapeMKLDNNKernel<T, op_name> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     RunKernel(ctx);
@@ -239,14 +319,9 @@ class ReshapeGradMKLDNNKernel : public ReshapeMKLDNNKernel<T> {
     auto* dout = ctx.Input<LoDTensor>(framework::GradVarName("Out"));
     auto* dx = ctx.Output<LoDTensor>(framework::GradVarName("X"));
 
-    framework::DDim x_dims;
-    // if reshape or squeeze
-    if (ctx.Type().find("2") == std::string::npos) {
-      x_dims = dx->dims();
-    } else {
-      auto xshape_dims = ctx.Input<framework::LoDTensor>("XShape")->dims();
-      x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
-    }
+    framework::DDim dx_dims;
+    InferOutputShapeInGrad(ctx, dx_dims);
 
     auto dout_vec_dims = framework::vectorize(dout->dims());
 
     mkldnn::memory::data_type dout_type =
@@ -265,44 +340,128 @@ class ReshapeGradMKLDNNKernel : public ReshapeMKLDNNKernel<T> {
     reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
     astream.wait();
 
-    dx->Resize(x_dims);
+    dx->Resize(dx_dims);
     dx->set_layout(framework::DataLayout::kMKLDNN);
     dx->set_format(GetMKLDNNFormat(reorder_dst_memory_p->get_desc().reshape(
-        framework::vectorize(x_dims))));
+        framework::vectorize(dx_dims))));
   }
+
+  void InferOutputShapeInGrad(const framework::ExecutionContext& ctx,
+                              framework::DDim& x_dims) const {
+    switch (op_name) {
+      case ReshapeKernelOpName::reshape:
+        InferShapeReshapeSqueezeGradOp(ctx, x_dims);
+        break;
+      case ReshapeKernelOpName::reshape2:
+        InferShapeReshape2Squeeze2Flatten2GradOp(ctx, x_dims);
+        break;
+      case ReshapeKernelOpName::squeeze:
+        InferShapeReshapeSqueezeGradOp(ctx, x_dims);
+        break;
+      case ReshapeKernelOpName::squeeze2:
+        InferShapeReshape2Squeeze2Flatten2GradOp(ctx, x_dims);
+        break;
+      case ReshapeKernelOpName::flatten:
+        InferShapeFlattenGradOp(ctx, x_dims);
+        break;
+      case ReshapeKernelOpName::flatten2:
+        InferShapeReshape2Squeeze2Flatten2GradOp(ctx, x_dims);
+        break;
+      default:
+        PADDLE_THROW(paddle::platform::errors::OutOfRange(
+            "Reshape grad kernel does not support that operator name"));
+    }
+  }
+
+  void InferShapeReshapeSqueezeGradOp(const framework::ExecutionContext& ctx,
+                                      framework::DDim& dx_dims) const {
+    auto* dx = ctx.Output<LoDTensor>(framework::GradVarName("X"));
+    dx_dims = dx->dims();
+  }
+
+  void InferShapeReshape2Squeeze2Flatten2GradOp(
+      const framework::ExecutionContext& ctx, framework::DDim& dx_dims) const {
+    auto xshape_dims = ctx.Input<framework::LoDTensor>("XShape")->dims();
+    dx_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
+  }
+
+  void InferShapeFlattenGradOp(const framework::ExecutionContext& ctx,
+                               framework::DDim& dx_dims) const {
+    dx_dims = ctx.Input<LoDTensor>("X")->dims();
+  }
 };
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP_KERNEL(squeeze, MKLDNN, paddle::platform::CPUPlace,
-                   ops::ReshapeMKLDNNKernel<float>,
-                   ops::ReshapeMKLDNNKernel<paddle::platform::bfloat16>);
+REGISTER_OP_KERNEL(
+    squeeze, MKLDNN, paddle::platform::CPUPlace,
+    ops::ReshapeMKLDNNKernel<float, ReshapeKernelOpName::squeeze>,
+    ops::ReshapeMKLDNNKernel<paddle::platform::bfloat16,
+                             ReshapeKernelOpName::squeeze>);
 
-REGISTER_OP_KERNEL(squeeze_grad, MKLDNN, paddle::platform::CPUPlace,
-                   ops::ReshapeGradMKLDNNKernel<float>,
-                   ops::ReshapeGradMKLDNNKernel<paddle::platform::bfloat16>);
+REGISTER_OP_KERNEL(
+    squeeze_grad, MKLDNN, paddle::platform::CPUPlace,
+    ops::ReshapeGradMKLDNNKernel<float, ReshapeKernelOpName::squeeze>,
+    ops::ReshapeGradMKLDNNKernel<paddle::platform::bfloat16,
+                                 ReshapeKernelOpName::squeeze>);
 
-REGISTER_OP_KERNEL(squeeze2, MKLDNN, paddle::platform::CPUPlace,
-                   ops::ReshapeMKLDNNKernel<float>,
-                   ops::ReshapeMKLDNNKernel<paddle::platform::bfloat16>);
+REGISTER_OP_KERNEL(
+    squeeze2, MKLDNN, paddle::platform::CPUPlace,
+    ops::ReshapeMKLDNNKernel<float, ReshapeKernelOpName::squeeze2>,
+    ops::ReshapeMKLDNNKernel<paddle::platform::bfloat16,
+                             ReshapeKernelOpName::squeeze2>);
 
-REGISTER_OP_KERNEL(squeeze2_grad, MKLDNN, paddle::platform::CPUPlace,
-                   ops::ReshapeGradMKLDNNKernel<float>,
-                   ops::ReshapeGradMKLDNNKernel<paddle::platform::bfloat16>);
+REGISTER_OP_KERNEL(
+    squeeze2_grad, MKLDNN, paddle::platform::CPUPlace,
+    ops::ReshapeGradMKLDNNKernel<float, ReshapeKernelOpName::squeeze2>,
+    ops::ReshapeGradMKLDNNKernel<paddle::platform::bfloat16,
+                                 ReshapeKernelOpName::squeeze2>);
 
-REGISTER_OP_KERNEL(reshape, MKLDNN, paddle::platform::CPUPlace,
-                   ops::ReshapeMKLDNNKernel<float>,
-                   ops::ReshapeMKLDNNKernel<paddle::platform::bfloat16>);
+REGISTER_OP_KERNEL(
+    reshape, MKLDNN, paddle::platform::CPUPlace,
+    ops::ReshapeMKLDNNKernel<float, ReshapeKernelOpName::reshape>,
+    ops::ReshapeMKLDNNKernel<paddle::platform::bfloat16,
+                             ReshapeKernelOpName::reshape>);
 
-REGISTER_OP_KERNEL(reshape_grad, MKLDNN, paddle::platform::CPUPlace,
-                   ops::ReshapeGradMKLDNNKernel<float>,
-                   ops::ReshapeGradMKLDNNKernel<paddle::platform::bfloat16>);
+REGISTER_OP_KERNEL(
+    reshape_grad, MKLDNN, paddle::platform::CPUPlace,
+    ops::ReshapeGradMKLDNNKernel<float, ReshapeKernelOpName::reshape>,
+    ops::ReshapeGradMKLDNNKernel<paddle::platform::bfloat16,
+                                 ReshapeKernelOpName::reshape>);
 
-REGISTER_OP_KERNEL(reshape2, MKLDNN, paddle::platform::CPUPlace,
-                   ops::ReshapeMKLDNNKernel<float>,
-                   ops::ReshapeMKLDNNKernel<paddle::platform::bfloat16>);
+REGISTER_OP_KERNEL(
+    reshape2, MKLDNN, paddle::platform::CPUPlace,
+    ops::ReshapeMKLDNNKernel<float, ReshapeKernelOpName::reshape2>,
+    ops::ReshapeMKLDNNKernel<paddle::platform::bfloat16,
+                             ReshapeKernelOpName::reshape2>);
 
-REGISTER_OP_KERNEL(reshape2_grad, MKLDNN, paddle::platform::CPUPlace,
-                   ops::ReshapeGradMKLDNNKernel<float>,
-                   ops::ReshapeGradMKLDNNKernel<paddle::platform::bfloat16>);
+REGISTER_OP_KERNEL(
+    reshape2_grad, MKLDNN, paddle::platform::CPUPlace,
+    ops::ReshapeGradMKLDNNKernel<float, ReshapeKernelOpName::reshape2>,
+    ops::ReshapeGradMKLDNNKernel<paddle::platform::bfloat16,
+                                 ReshapeKernelOpName::reshape2>);
+
+REGISTER_OP_KERNEL(
+    flatten, MKLDNN, paddle::platform::CPUPlace,
+    ops::ReshapeMKLDNNKernel<float, ReshapeKernelOpName::flatten>,
+    ops::ReshapeMKLDNNKernel<paddle::platform::bfloat16,
+                             ReshapeKernelOpName::flatten>);
+
+REGISTER_OP_KERNEL(
+    flatten_grad, MKLDNN, paddle::platform::CPUPlace,
+    ops::ReshapeGradMKLDNNKernel<float, ReshapeKernelOpName::flatten>,
+    ops::ReshapeGradMKLDNNKernel<paddle::platform::bfloat16,
+                                 ReshapeKernelOpName::flatten>);
+
+REGISTER_OP_KERNEL(
+    flatten2, MKLDNN, paddle::platform::CPUPlace,
+    ops::ReshapeMKLDNNKernel<float, ReshapeKernelOpName::flatten2>,
+    ops::ReshapeMKLDNNKernel<paddle::platform::bfloat16,
+                             ReshapeKernelOpName::flatten2>);
+
+REGISTER_OP_KERNEL(
+    flatten2_grad, MKLDNN, paddle::platform::CPUPlace,
+    ops::ReshapeGradMKLDNNKernel<float, ReshapeKernelOpName::flatten2>,
+    ops::ReshapeGradMKLDNNKernel<paddle::platform::bfloat16,
+                                 ReshapeKernelOpName::flatten2>);
@@ -248,13 +248,13 @@ class ReshapeOp : public framework::OperatorWithKernel {
     auto input_data_type =
         framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
 
-#ifdef PADDLE_WITH_MKLDNN
+    //#ifdef PADDLE_WITH_MKLDNN
     //    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
     //      return framework::OpKernelType(input_data_type, ctx.GetPlace(),
     //                                     framework::DataLayout::kMKLDNN,
     //                                     framework::LibraryType::kMKLDNN);
     //    }
-#endif
+    //#endif
     return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
 
@@ -366,13 +366,13 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
     auto input_data_type =
         framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
 
-#ifdef PADDLE_WITH_MKLDNN
+    //#ifdef PADDLE_WITH_MKLDNN
     //    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
     //      return framework::OpKernelType(input_data_type, ctx.GetPlace(),
     //                                     framework::DataLayout::kMKLDNN,
     //                                     framework::LibraryType::kMKLDNN);
     //    }
-#endif
+    //#endif
     return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
 };
 
@@ -557,13 +557,13 @@ class Reshape2GradOp : public framework::OperatorWithKernel {
     auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(
         ctx, framework::GradVarName("Out"));
 
-#ifdef PADDLE_WITH_MKLDNN
+    //#ifdef PADDLE_WITH_MKLDNN
     //    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
     //      return framework::OpKernelType(input_data_type, ctx.GetPlace(),
     //                                     framework::DataLayout::kMKLDNN,
     //                                     framework::LibraryType::kMKLDNN);
     //    }
-#endif
+    //#endif
     return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
...
@@ -113,13 +113,13 @@ class SqueezeOp : public framework::OperatorWithKernel {
     auto input_data_type =
         framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
 
-#ifdef PADDLE_WITH_MKLDNN
+    //#ifdef PADDLE_WITH_MKLDNN
     //    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
     //      return framework::OpKernelType(input_data_type, ctx.GetPlace(),
     //                                     framework::DataLayout::kMKLDNN,
     //                                     framework::LibraryType::kMKLDNN);
     //    }
-#endif
+    //#endif
     return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
 
@@ -140,13 +140,13 @@ class SqueezeGradOp : public framework::OperatorWithKernel {
     auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(
         ctx, framework::GradVarName("Out"));
 
-#ifdef PADDLE_WITH_MKLDNN
+    //#ifdef PADDLE_WITH_MKLDNN
     //    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
     //      return framework::OpKernelType(input_data_type, ctx.GetPlace(),
     //                                     framework::DataLayout::kMKLDNN,
     //                                     framework::LibraryType::kMKLDNN);
     //    }
-#endif
+    //#endif
     return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
 };
 
@@ -241,13 +241,13 @@ class Squeeze2Op : public framework::OperatorWithKernel {
     auto input_data_type =
         framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
 
-#ifdef PADDLE_WITH_MKLDNN
+    //#ifdef PADDLE_WITH_MKLDNN
     //    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
     //      return framework::OpKernelType(input_data_type, ctx.GetPlace(),
     //                                     framework::DataLayout::kMKLDNN,
     //                                     framework::LibraryType::kMKLDNN);
     //    }
-#endif
+    //#endif
     return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
 };
 
@@ -287,13 +287,13 @@ class Squeeze2GradOp : public framework::OperatorWithKernel {
     auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(
        ctx, framework::GradVarName("Out"));
 
-#ifdef PADDLE_WITH_MKLDNN
+    //#ifdef PADDLE_WITH_MKLDNN
     //    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
     //      return framework::OpKernelType(input_data_type, ctx.GetPlace(),
     //                                     framework::DataLayout::kMKLDNN,
     //                                     framework::LibraryType::kMKLDNN);
     //    }
-#endif
+    //#endif
     return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
 };
...
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool, convert_float_to_uint16
@OpTestTool.skip_if_not_cpu_bf16()
class TestFlattenOneDNNOp(OpTest):
def setUp(self):
self.set_op_type()
self.init_test_case()
self.set_inputs()
self.attrs = {"axis": self.axis, 'use_mkldnn': True}
self.ori_shape = self.inputs['X'].shape
self.outputs = {"Out": self.inputs["X"].copy().reshape(self.new_shape)}
def set_inputs(self):
self.inputs = {"X": np.random.random(self.in_shape).astype("float32")}
def set_op_type(self):
self.op_type = "flatten"
def test_check_output(self):
self.check_output_with_place(core.CPUPlace())
def test_check_grad(self):
self.check_grad_with_place(core.CPUPlace(), ["X"], "Out")
def init_test_case(self):
self.in_shape = (3, 2, 2, 10)
self.axis = 1
self.new_shape = (3, 40)
class TestFlattenOneDNNOp1(TestFlattenOneDNNOp):
def init_test_case(self):
self.in_shape = (3, 2, 2, 10)
self.axis = 0
self.new_shape = (1, 120)
class TestFlattenOneDNNOpSixDims(TestFlattenOneDNNOp):
def init_test_case(self):
self.in_shape = (3, 2, 3, 2, 4, 4)
self.axis = 4
self.new_shape = (36, 16)
class TestFlatten2OneDNNOp(TestFlattenOneDNNOp):
def set_op_type(self):
self.op_type = "flatten2"
class TestFlatten2OneDNNOp1(TestFlattenOneDNNOp1):
def set_op_type(self):
self.op_type = "flatten2"
class TestFlatten2OneDNNOpSixDims(TestFlattenOneDNNOpSixDims):
def set_op_type(self):
self.op_type = "flatten2"
# BF16 TESTS
def create_flatten_bf16_test_classes(parent):
class TestFlatten2BF16OneDNNOp(parent):
def set_inputs(self):
self.dtype = np.uint16
self.inputs = {
"X": np.random.random(self.in_shape).astype("uint16")
}
def calculate_grads(self):
self.dout = self.outputs['Out']
self.dx = np.reshape(self.dout, self.ori_shape)
def test_check_output(self):
self.check_output_with_place(
core.CPUPlace(), no_check_set=["XShape"])
def test_check_grad(self):
self.calculate_grads()
self.check_grad_with_place(
core.CPUPlace(), ["X"],
"Out",
user_defined_grads=[self.dx],
user_defined_grad_outputs=[self.dout])
cls_name = "{0}_{1}".format(parent.__name__, "Flatten2_BF16")
TestFlatten2BF16OneDNNOp.__name__ = cls_name
globals()[cls_name] = TestFlatten2BF16OneDNNOp
class TestFlattenBF16OneDNNOp(parent):
def set_op_type(self):
self.dtype = np.uint16
self.op_type = "flatten"
def set_inputs(self):
self.dtype = np.uint16
self.inputs = {
"X": np.random.random(self.in_shape).astype("uint16")
}
def set_outputs(self):
self.outputs = {"Out": self.x.reshape(self.new_shape)}
def calculate_grads(self):
self.dout = self.outputs['Out']
self.dx = np.reshape(self.dout, self.ori_shape)
def test_check_output(self):
self.check_output_with_place(core.CPUPlace())
def test_check_grad(self):
self.calculate_grads()
self.check_grad_with_place(
core.CPUPlace(), ["X"],
"Out",
user_defined_grads=[self.dx],
user_defined_grad_outputs=[convert_float_to_uint16(self.dout)])
cls_name = "{0}_{1}".format(parent.__name__, "Flatten_BF16")
TestFlattenBF16OneDNNOp.__name__ = cls_name
globals()[cls_name] = TestFlattenBF16OneDNNOp
create_flatten_bf16_test_classes(TestFlatten2OneDNNOp)
create_flatten_bf16_test_classes(TestFlatten2OneDNNOp1)
create_flatten_bf16_test_classes(TestFlatten2OneDNNOpSixDims)
if __name__ == "__main__":
paddle.enable_static()
unittest.main()