Unverified · Commit e427a0f1 · Authored by: J jakpiase · Committed by: GitHub

Added flatten and flatten2 BF16/FP32 FWD/BWD kernels (#35892)

* refactored reshape multi-op kernel and added flatten/flatten2 kernels

* added formatting for flatten tests

* CI fix

* disabled reshape_kernel ops after successful CI run

* minor fix
Parent: ec2f68e8
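For context, the refactor replaces the previous runtime dispatch on ctx.Type() with a compile-time template parameter. A minimal, self-contained sketch of that pattern (illustrative only, not Paddle code; all names below are hypothetical):

    #include <cstdio>

    enum class ReshapeKernelOpName { reshape, reshape2, squeeze, squeeze2, flatten, flatten2 };

    template <typename T, ReshapeKernelOpName op_name>
    class KernelSketch {
     public:
      void Compute() const {
        // op_name is a template constant, so each instantiation keeps exactly
        // one branch of this switch; no runtime string comparison is needed.
        switch (op_name) {
          case ReshapeKernelOpName::flatten:
          case ReshapeKernelOpName::flatten2:
            std::puts("flatten-style shape inference");
            break;
          case ReshapeKernelOpName::squeeze:
          case ReshapeKernelOpName::squeeze2:
            std::puts("squeeze-style shape inference");
            break;
          default:
            std::puts("reshape-style shape inference");
        }
      }
    };

    int main() {
      // One instantiation per (data type, op) pair, mirroring the
      // REGISTER_OP_KERNEL calls at the bottom of the mkldnn kernel file.
      KernelSketch<float, ReshapeKernelOpName::flatten>().Compute();
      KernelSketch<float, ReshapeKernelOpName::squeeze2>().Compute();
      return 0;
    }

Each REGISTER_OP_KERNEL call in the diff below then pins one enum value per registered op.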
@@ -77,9 +77,17 @@ class FlattenOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
//#ifdef PADDLE_WITH_MKLDNN
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
// return framework::OpKernelType(input_data_type, ctx.GetPlace(),
// framework::DataLayout::kMKLDNN,
// framework::LibraryType::kMKLDNN);
// }
//#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
@@ -101,6 +109,14 @@ class FlattenOpMaker : public framework::OpProtoAndCheckerMaker {
"tensor is (1, (d_0 X d_1 ... d_n), where the shape of the"
"input tensor is (d_0, d_1, ... d_n).")
.SetDefault(1);
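// Example (shapes taken from the unit tests added in this commit): with
// axis = 1, an input of shape (3, 2, 2, 10) flattens to (3, 40); with
// axis = 0 it flattens to (1, 120).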
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<std::string>(
"mkldnn_data_type",
"(string, default \"float32\"). Data type of mkldnn kernel")
.SetDefault("float32")
.InEnum({"float32", "bfloat16"});
AddComment(R"DOC(
Flatten Operator
@@ -139,9 +155,17 @@ class FlattenGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
//#ifdef PADDLE_WITH_MKLDNN
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
// return framework::OpKernelType(input_data_type, ctx.GetPlace(),
// framework::DataLayout::kMKLDNN,
// framework::LibraryType::kMKLDNN);
// }
//#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
@@ -198,6 +222,21 @@ class Flatten2Op : public framework::OperatorWithKernel {
ctx->SetOutputDim("XShape", framework::make_ddim(xshape_dims));
ctx->ShareLoD("X", "XShape");
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
//#ifdef PADDLE_WITH_MKLDNN
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
// return framework::OpKernelType(input_data_type, ctx.GetPlace(),
// framework::DataLayout::kMKLDNN,
// framework::LibraryType::kMKLDNN);
// }
//#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
class Flatten2OpMaker : public FlattenOpMaker {
@@ -244,9 +283,17 @@ class Flatten2GradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
//#ifdef PADDLE_WITH_MKLDNN
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
// return framework::OpKernelType(input_data_type, ctx.GetPlace(),
// framework::DataLayout::kMKLDNN,
// framework::LibraryType::kMKLDNN);
// }
//#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
@@ -12,9 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/flatten_op.h"
#include "paddle/fluid/operators/squeeze_op.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
namespace {
enum class ReshapeKernelOpName {
reshape,
reshape2,
squeeze,
squeeze2,
flatten,
flatten2,
};
} // anonymous namespace
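// The op identity is now a compile-time template parameter of the kernels
// below, replacing the former runtime checks on ctx.Type().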
namespace paddle {
namespace operators {
@@ -41,7 +53,7 @@ static std::vector<int> extract_shape(
return vec_new_shape;
}
template <typename T>
template <typename T, ReshapeKernelOpName op_name>
class ReshapeMKLDNNKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
@@ -55,43 +67,13 @@ class ReshapeMKLDNNKernel : public framework::OpKernel<T> {
const auto& onednn_engine = dev_ctx.GetEngine();
auto* x = ctx.Input<LoDTensor>("X");
auto* xshape = ctx.Output<LoDTensor>("XShape");
auto* out = ctx.Output<LoDTensor>("Out");
framework::DDim x_dims;
// if reshape or squeeze
if (ctx.Type().find("2") == std::string::npos) {
x_dims = x->dims();
} else {
auto xshape_dims = xshape->dims();
x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
}
framework::DDim x_dims, out_dims;
InferInOutShape(ctx, x_dims, out_dims);
auto x_vec_dims = framework::vectorize(x_dims);
framework::DDim out_dims;
if (ctx.Type() == "squeeze") {
auto& axes = ctx.Attr<std::vector<int>>("axes");
out_dims = GetOutputShape(axes, x_dims, true);
} else {
out_dims = out->dims();
}
if (ctx.Type().find("reshape") != std::string::npos) {
auto list_new_shape_tensor = ctx.MultiInput<Tensor>("ShapeTensor");
if (list_new_shape_tensor.size() > 0) {
auto new_shape = extract_shape(list_new_shape_tensor);
out_dims = ValidateShape(new_shape, x_dims);
} else if (ctx.HasInput("Shape")) {
auto* shape_tensor = ctx.Input<framework::LoDTensor>("Shape");
auto* shape_data = shape_tensor->data<int>();
auto shape =
std::vector<int>(shape_data, shape_data + shape_tensor->numel());
out_dims = ValidateShape(shape, x_dims);
}
}
mkldnn::memory::data_type x_type = framework::ToMKLDNNDataType(x->type());
platform::ReorderMKLDNNHandler reorder_handler(x_vec_dims, x->type(),
x_type, onednn_engine);
@@ -116,6 +98,104 @@ class ReshapeMKLDNNKernel : public framework::OpKernel<T> {
framework::vectorize(out_dims))));
}
void InferInOutShape(const framework::ExecutionContext& ctx,
framework::DDim& x_dims,
framework::DDim& out_dims) const {
switch (op_name) {
case ReshapeKernelOpName::reshape:
InferShapeReshapeOp(ctx, x_dims, out_dims);
break;
case ReshapeKernelOpName::reshape2:
InferShapeReshape2Op(ctx, x_dims, out_dims);
break;
case ReshapeKernelOpName::squeeze:
InferShapeSqueezeOp(ctx, x_dims, out_dims);
break;
case ReshapeKernelOpName::squeeze2:
InferShapeSqueeze2Op(ctx, x_dims, out_dims);
break;
case ReshapeKernelOpName::flatten:
InferShapeFlattenOp(ctx, x_dims, out_dims);
break;
case ReshapeKernelOpName::flatten2:
InferShapeFlattenOp(ctx, x_dims, out_dims);
break;
default:
PADDLE_THROW(paddle::platform::errors::OutOfRange(
"Reshape kernel doesn not support that operator name"));
}
}
void InferShapeReshapeOp(const framework::ExecutionContext& ctx,
framework::DDim& x_dims,
framework::DDim& out_dims) const {
auto* x = ctx.Input<LoDTensor>("X");
auto* out = ctx.Output<LoDTensor>("Out");
x_dims = x->dims();
out_dims = out->dims();
ChangeReshapeOutDimsIfNeeded(ctx, x_dims, out_dims);
}
void InferShapeReshape2Op(const framework::ExecutionContext& ctx,
framework::DDim& x_dims,
framework::DDim& out_dims) const {
auto* out = ctx.Output<LoDTensor>("Out");
auto* xshape = ctx.Output<LoDTensor>("XShape");
auto xshape_dims = xshape->dims();
x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
out_dims = out->dims();
ChangeReshapeOutDimsIfNeeded(ctx, x_dims, out_dims);
}
// In reshape/reshape2 ops, "ShapeTensor" has the highest priority and
// "Shape" the second-highest
void ChangeReshapeOutDimsIfNeeded(const framework::ExecutionContext& ctx,
framework::DDim& x_dims,
framework::DDim& out_dims) const {
auto list_new_shape_tensor = ctx.MultiInput<Tensor>("ShapeTensor");
if (list_new_shape_tensor.size() > 0) {
auto new_shape = extract_shape(list_new_shape_tensor);
out_dims = ValidateShape(new_shape, x_dims);
} else if (ctx.HasInput("Shape")) {
auto* shape_tensor = ctx.Input<framework::LoDTensor>("Shape");
auto* shape_data = shape_tensor->data<int>();
auto shape =
std::vector<int>(shape_data, shape_data + shape_tensor->numel());
out_dims = ValidateShape(shape, x_dims);
}
}
void InferShapeSqueezeOp(const framework::ExecutionContext& ctx,
framework::DDim& x_dims,
framework::DDim& out_dims) const {
auto* x = ctx.Input<LoDTensor>("X");
x_dims = x->dims();
const auto& axes = ctx.Attr<std::vector<int>>("axes");
out_dims = GetOutputShape(axes, x_dims, true);
}
void InferShapeSqueeze2Op(const framework::ExecutionContext& ctx,
framework::DDim& x_dims,
framework::DDim& out_dims) const {
auto* out = ctx.Output<LoDTensor>("Out");
auto* xshape = ctx.Output<LoDTensor>("XShape");
auto xshape_dims = xshape->dims();
x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
out_dims = out->dims();
}
void InferShapeFlattenOp(const framework::ExecutionContext& ctx,
framework::DDim& x_dims,
framework::DDim& out_dims) const {
auto x = ctx.Input<LoDTensor>("X");
x_dims = x->dims();
auto axes = ctx.Attr<int>("axis");
out_dims = framework::make_ddim(
FlattenKernel<platform::CPUDeviceContext, float>::GetOutputShape(
axes, x_dims));
}
protected:
static mkldnn::memory::format_tag getPlainFormatTag(const Tensor* tensor) {
auto tensor_dims_size = tensor->dims().size();
@@ -223,8 +303,8 @@ class ReshapeMKLDNNKernel : public framework::OpKernel<T> {
}
};
template <typename T>
class ReshapeGradMKLDNNKernel : public ReshapeMKLDNNKernel<T> {
template <typename T, ReshapeKernelOpName op_name>
class ReshapeGradMKLDNNKernel : public ReshapeMKLDNNKernel<T, op_name> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
RunKernel(ctx);
@@ -239,14 +319,9 @@ class ReshapeGradMKLDNNKernel : public ReshapeMKLDNNKernel<T> {
auto* dout = ctx.Input<LoDTensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<LoDTensor>(framework::GradVarName("X"));
framework::DDim x_dims;
// if reshape or squeeze
if (ctx.Type().find("2") == std::string::npos) {
x_dims = dx->dims();
} else {
auto xshape_dims = ctx.Input<framework::LoDTensor>("XShape")->dims();
x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
}
framework::DDim dx_dims;
InferOutputShapeInGrad(ctx, dx_dims);
auto dout_vec_dims = framework::vectorize(dout->dims());
mkldnn::memory::data_type dout_type =
@@ -265,44 +340,128 @@ class ReshapeGradMKLDNNKernel : public ReshapeMKLDNNKernel<T> {
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
astream.wait();
dx->Resize(x_dims);
dx->Resize(dx_dims);
dx->set_layout(framework::DataLayout::kMKLDNN);
dx->set_format(GetMKLDNNFormat(reorder_dst_memory_p->get_desc().reshape(
framework::vectorize(x_dims))));
framework::vectorize(dx_dims))));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(squeeze, MKLDNN, paddle::platform::CPUPlace,
ops::ReshapeMKLDNNKernel<float>,
ops::ReshapeMKLDNNKernel<paddle::platform::bfloat16>);
REGISTER_OP_KERNEL(squeeze_grad, MKLDNN, paddle::platform::CPUPlace,
ops::ReshapeGradMKLDNNKernel<float>,
ops::ReshapeGradMKLDNNKernel<paddle::platform::bfloat16>);
REGISTER_OP_KERNEL(squeeze2, MKLDNN, paddle::platform::CPUPlace,
ops::ReshapeMKLDNNKernel<float>,
ops::ReshapeMKLDNNKernel<paddle::platform::bfloat16>);
REGISTER_OP_KERNEL(squeeze2_grad, MKLDNN, paddle::platform::CPUPlace,
ops::ReshapeGradMKLDNNKernel<float>,
ops::ReshapeGradMKLDNNKernel<paddle::platform::bfloat16>);
void InferOutputShapeInGrad(const framework::ExecutionContext& ctx,
framework::DDim& x_dims) const {
switch (op_name) {
case ReshapeKernelOpName::reshape:
InferShapeReshapeSqueezeGradOp(ctx, x_dims);
break;
case ReshapeKernelOpName::reshape2:
InferShapeReshape2Squeeze2Flatten2GradOp(ctx, x_dims);
break;
case ReshapeKernelOpName::squeeze:
InferShapeReshapeSqueezeGradOp(ctx, x_dims);
break;
case ReshapeKernelOpName::squeeze2:
InferShapeReshape2Squeeze2Flatten2GradOp(ctx, x_dims);
break;
case ReshapeKernelOpName::flatten:
InferShapeFlattenGradOp(ctx, x_dims);
break;
case ReshapeKernelOpName::flatten2:
InferShapeReshape2Squeeze2Flatten2GradOp(ctx, x_dims);
break;
default:
PADDLE_THROW(paddle::platform::errors::OutOfRange(
"Reshape grad kernel doesn not support that operator name"));
}
}
REGISTER_OP_KERNEL(reshape, MKLDNN, paddle::platform::CPUPlace,
ops::ReshapeMKLDNNKernel<float>,
ops::ReshapeMKLDNNKernel<paddle::platform::bfloat16>);
void InferShapeReshapeSqueezeGradOp(const framework::ExecutionContext& ctx,
framework::DDim& dx_dims) const {
auto* dx = ctx.Output<LoDTensor>(framework::GradVarName("X"));
dx_dims = dx->dims();
}
REGISTER_OP_KERNEL(reshape_grad, MKLDNN, paddle::platform::CPUPlace,
ops::ReshapeGradMKLDNNKernel<float>,
ops::ReshapeGradMKLDNNKernel<paddle::platform::bfloat16>);
void InferShapeReshape2Squeeze2Flatten2GradOp(
const framework::ExecutionContext& ctx, framework::DDim& dx_dims) const {
auto xshape_dims = ctx.Input<framework::LoDTensor>("XShape")->dims();
dx_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
}
REGISTER_OP_KERNEL(reshape2, MKLDNN, paddle::platform::CPUPlace,
ops::ReshapeMKLDNNKernel<float>,
ops::ReshapeMKLDNNKernel<paddle::platform::bfloat16>);
void InferShapeFlattenGradOp(const framework::ExecutionContext& ctx,
framework::DDim& dx_dims) const {
dx_dims = ctx.Input<LoDTensor>("X")->dims();
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP_KERNEL(reshape2_grad, MKLDNN, paddle::platform::CPUPlace,
ops::ReshapeGradMKLDNNKernel<float>,
ops::ReshapeGradMKLDNNKernel<paddle::platform::bfloat16>);
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(
squeeze, MKLDNN, paddle::platform::CPUPlace,
ops::ReshapeMKLDNNKernel<float, ReshapeKernelOpName::squeeze>,
ops::ReshapeMKLDNNKernel<paddle::platform::bfloat16,
ReshapeKernelOpName::squeeze>);
REGISTER_OP_KERNEL(
squeeze_grad, MKLDNN, paddle::platform::CPUPlace,
ops::ReshapeGradMKLDNNKernel<float, ReshapeKernelOpName::squeeze>,
ops::ReshapeGradMKLDNNKernel<paddle::platform::bfloat16,
ReshapeKernelOpName::squeeze>);
REGISTER_OP_KERNEL(
squeeze2, MKLDNN, paddle::platform::CPUPlace,
ops::ReshapeMKLDNNKernel<float, ReshapeKernelOpName::squeeze2>,
ops::ReshapeMKLDNNKernel<paddle::platform::bfloat16,
ReshapeKernelOpName::squeeze2>);
REGISTER_OP_KERNEL(
squeeze2_grad, MKLDNN, paddle::platform::CPUPlace,
ops::ReshapeGradMKLDNNKernel<float, ReshapeKernelOpName::squeeze2>,
ops::ReshapeGradMKLDNNKernel<paddle::platform::bfloat16,
ReshapeKernelOpName::squeeze2>);
REGISTER_OP_KERNEL(
reshape, MKLDNN, paddle::platform::CPUPlace,
ops::ReshapeMKLDNNKernel<float, ReshapeKernelOpName::reshape>,
ops::ReshapeMKLDNNKernel<paddle::platform::bfloat16,
ReshapeKernelOpName::reshape>);
REGISTER_OP_KERNEL(
reshape_grad, MKLDNN, paddle::platform::CPUPlace,
ops::ReshapeGradMKLDNNKernel<float, ReshapeKernelOpName::reshape>,
ops::ReshapeGradMKLDNNKernel<paddle::platform::bfloat16,
ReshapeKernelOpName::reshape>);
REGISTER_OP_KERNEL(
reshape2, MKLDNN, paddle::platform::CPUPlace,
ops::ReshapeMKLDNNKernel<float, ReshapeKernelOpName::reshape2>,
ops::ReshapeMKLDNNKernel<paddle::platform::bfloat16,
ReshapeKernelOpName::reshape2>);
REGISTER_OP_KERNEL(
reshape2_grad, MKLDNN, paddle::platform::CPUPlace,
ops::ReshapeGradMKLDNNKernel<float, ReshapeKernelOpName::reshape2>,
ops::ReshapeGradMKLDNNKernel<paddle::platform::bfloat16,
ReshapeKernelOpName::reshape2>);
REGISTER_OP_KERNEL(
flatten, MKLDNN, paddle::platform::CPUPlace,
ops::ReshapeMKLDNNKernel<float, ReshapeKernelOpName::flatten>,
ops::ReshapeMKLDNNKernel<paddle::platform::bfloat16,
ReshapeKernelOpName::flatten>);
REGISTER_OP_KERNEL(
flatten_grad, MKLDNN, paddle::platform::CPUPlace,
ops::ReshapeGradMKLDNNKernel<float, ReshapeKernelOpName::flatten>,
ops::ReshapeGradMKLDNNKernel<paddle::platform::bfloat16,
ReshapeKernelOpName::flatten>);
REGISTER_OP_KERNEL(
flatten2, MKLDNN, paddle::platform::CPUPlace,
ops::ReshapeMKLDNNKernel<float, ReshapeKernelOpName::flatten2>,
ops::ReshapeMKLDNNKernel<paddle::platform::bfloat16,
ReshapeKernelOpName::flatten2>);
REGISTER_OP_KERNEL(
flatten2_grad, MKLDNN, paddle::platform::CPUPlace,
ops::ReshapeGradMKLDNNKernel<float, ReshapeKernelOpName::flatten2>,
ops::ReshapeGradMKLDNNKernel<paddle::platform::bfloat16,
ReshapeKernelOpName::flatten2>);
@@ -248,13 +248,13 @@ class ReshapeOp : public framework::OperatorWithKernel {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
// return framework::OpKernelType(input_data_type, ctx.GetPlace(),
// framework::DataLayout::kMKLDNN,
// framework::LibraryType::kMKLDNN);
// }
#endif
//#ifdef PADDLE_WITH_MKLDNN
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
// return framework::OpKernelType(input_data_type, ctx.GetPlace(),
// framework::DataLayout::kMKLDNN,
// framework::LibraryType::kMKLDNN);
// }
//#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
@@ -366,13 +366,13 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
// return framework::OpKernelType(input_data_type, ctx.GetPlace(),
// framework::DataLayout::kMKLDNN,
// framework::LibraryType::kMKLDNN);
// }
#endif
//#ifdef PADDLE_WITH_MKLDNN
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
// return framework::OpKernelType(input_data_type, ctx.GetPlace(),
// framework::DataLayout::kMKLDNN,
// framework::LibraryType::kMKLDNN);
// }
//#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
@@ -557,13 +557,13 @@ class Reshape2GradOp : public framework::OperatorWithKernel {
auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
#ifdef PADDLE_WITH_MKLDNN
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
// return framework::OpKernelType(input_data_type, ctx.GetPlace(),
// framework::DataLayout::kMKLDNN,
// framework::LibraryType::kMKLDNN);
// }
#endif
//#ifdef PADDLE_WITH_MKLDNN
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
// return framework::OpKernelType(input_data_type, ctx.GetPlace(),
// framework::DataLayout::kMKLDNN,
// framework::LibraryType::kMKLDNN);
// }
//#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
@@ -113,13 +113,13 @@ class SqueezeOp : public framework::OperatorWithKernel {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
// return framework::OpKernelType(input_data_type, ctx.GetPlace(),
// framework::DataLayout::kMKLDNN,
// framework::LibraryType::kMKLDNN);
// }
#endif
//#ifdef PADDLE_WITH_MKLDNN
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
// return framework::OpKernelType(input_data_type, ctx.GetPlace(),
// framework::DataLayout::kMKLDNN,
// framework::LibraryType::kMKLDNN);
// }
//#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
@@ -140,13 +140,13 @@ class SqueezeGradOp : public framework::OperatorWithKernel {
auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
#ifdef PADDLE_WITH_MKLDNN
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
// return framework::OpKernelType(input_data_type, ctx.GetPlace(),
// framework::DataLayout::kMKLDNN,
// framework::LibraryType::kMKLDNN);
// }
#endif
//#ifdef PADDLE_WITH_MKLDNN
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
// return framework::OpKernelType(input_data_type, ctx.GetPlace(),
// framework::DataLayout::kMKLDNN,
// framework::LibraryType::kMKLDNN);
// }
//#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
@@ -241,13 +241,13 @@ class Squeeze2Op : public framework::OperatorWithKernel {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
// return framework::OpKernelType(input_data_type, ctx.GetPlace(),
// framework::DataLayout::kMKLDNN,
// framework::LibraryType::kMKLDNN);
// }
#endif
//#ifdef PADDLE_WITH_MKLDNN
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
// return framework::OpKernelType(input_data_type, ctx.GetPlace(),
// framework::DataLayout::kMKLDNN,
// framework::LibraryType::kMKLDNN);
// }
//#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
@@ -287,13 +287,13 @@ class Squeeze2GradOp : public framework::OperatorWithKernel {
auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
#ifdef PADDLE_WITH_MKLDNN
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
// return framework::OpKernelType(input_data_type, ctx.GetPlace(),
// framework::DataLayout::kMKLDNN,
// framework::LibraryType::kMKLDNN);
// }
#endif
//#ifdef PADDLE_WITH_MKLDNN
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
// return framework::OpKernelType(input_data_type, ctx.GetPlace(),
// framework::DataLayout::kMKLDNN,
// framework::LibraryType::kMKLDNN);
// }
//#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool, convert_float_to_uint16
@OpTestTool.skip_if_not_cpu_bf16()
class TestFlattenOneDNNOp(OpTest):
def setUp(self):
self.set_op_type()
self.init_test_case()
self.set_inputs()
self.attrs = {"axis": self.axis, 'use_mkldnn': True}
self.ori_shape = self.inputs['X'].shape
self.outputs = {"Out": self.inputs["X"].copy().reshape(self.new_shape)}
def set_inputs(self):
self.inputs = {"X": np.random.random(self.in_shape).astype("float32")}
def set_op_type(self):
self.op_type = "flatten"
def test_check_output(self):
self.check_output_with_place(core.CPUPlace())
def test_check_grad(self):
self.check_grad_with_place(core.CPUPlace(), ["X"], "Out")
def init_test_case(self):
self.in_shape = (3, 2, 2, 10)
self.axis = 1
self.new_shape = (3, 40)
class TestFlattenOneDNNOp1(TestFlattenOneDNNOp):
def init_test_case(self):
self.in_shape = (3, 2, 2, 10)
self.axis = 0
self.new_shape = (1, 120)
class TestFlattenOneDNNOpSixDims(TestFlattenOneDNNOp):
def init_test_case(self):
self.in_shape = (3, 2, 3, 2, 4, 4)
self.axis = 4
self.new_shape = (36, 16)
class TestFlatten2OneDNNOp(TestFlattenOneDNNOp):
def set_op_type(self):
self.op_type = "flatten2"
class TestFlatten2OneDNNOp1(TestFlattenOneDNNOp1):
def set_op_type(self):
self.op_type = "flatten2"
class TestFlatten2OneDNNOpSixDims(TestFlattenOneDNNOpSixDims):
def set_op_type(self):
self.op_type = "flatten2"
# BF16 TESTS
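# create_flatten_bf16_test_classes derives BF16 variants (flatten and
# flatten2) from a given FP32 test class and registers them in the module
# globals so that unittest discovery picks them up.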
def create_flatten_bf16_test_classes(parent):
class TestFlatten2BF16OneDNNOp(parent):
def set_inputs(self):
self.dtype = np.uint16
self.inputs = {
"X": np.random.random(self.in_shape).astype("uint16")
}
def calculate_grads(self):
self.dout = self.outputs['Out']
self.dx = np.reshape(self.dout, self.ori_shape)
def test_check_output(self):
self.check_output_with_place(
core.CPUPlace(), no_check_set=["XShape"])
def test_check_grad(self):
self.calculate_grads()
self.check_grad_with_place(
core.CPUPlace(), ["X"],
"Out",
user_defined_grads=[self.dx],
user_defined_grad_outputs=[self.dout])
cls_name = "{0}_{1}".format(parent.__name__, "Flatten2_BF16")
TestFlatten2BF16OneDNNOp.__name__ = cls_name
globals()[cls_name] = TestFlatten2BF16OneDNNOp
class TestFlattenBF16OneDNNOp(parent):
def set_op_type(self):
self.dtype = np.uint16
self.op_type = "flatten"
def set_inputs(self):
self.dtype = np.uint16
self.inputs = {
"X": np.random.random(self.in_shape).astype("uint16")
}
def set_outputs(self):
self.outputs = {"Out": self.x.reshape(self.new_shape)}
def calculate_grads(self):
self.dout = self.outputs['Out']
self.dx = np.reshape(self.dout, self.ori_shape)
def test_check_output(self):
self.check_output_with_place(core.CPUPlace())
def test_check_grad(self):
self.calculate_grads()
self.check_grad_with_place(
core.CPUPlace(), ["X"],
"Out",
user_defined_grads=[self.dx],
user_defined_grad_outputs=[convert_float_to_uint16(self.dout)])
cls_name = "{0}_{1}".format(parent.__name__, "Flatten_BF16")
TestFlattenBF16OneDNNOp.__name__ = cls_name
globals()[cls_name] = TestFlattenBF16OneDNNOp
create_flatten_bf16_test_classes(TestFlatten2OneDNNOp)
create_flatten_bf16_test_classes(TestFlatten2OneDNNOp1)
create_flatten_bf16_test_classes(TestFlatten2OneDNNOpSixDims)
if __name__ == "__main__":
paddle.enable_static()
unittest.main()