Commit 9b5a2831 authored by TensorFlower Gardener

Merge pull request #24086 from Intel-tensorflow:nhasabni/fusedconv

PiperOrigin-RevId: 225099426
......@@ -260,6 +260,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
csinfo_.conv3d_grad_filter = "Conv3DBackpropFilterV2";
csinfo_.fused_batch_norm = "FusedBatchNorm";
csinfo_.fused_batch_norm_grad = "FusedBatchNormGrad";
csinfo_.fused_conv2d = "_FusedConv2D";
csinfo_.identity = "Identity";
csinfo_.lrn = "LRN";
csinfo_.lrn_grad = "LRNGrad";
......@@ -274,6 +275,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias";
csinfo_.mkl_conv2d_grad_filter_with_bias =
"_MklConv2DBackpropFilterWithBias";
csinfo_.mkl_fused_conv2d = "_MklFusedConv2D";
csinfo_.mkl_pad_with_conv2d = "_MklPadWithConv2D";
csinfo_.pad = "Pad";
csinfo_.pad_with_conv2d = "__MklDummyPadWithConv2D";
......@@ -380,6 +382,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
{csinfo_.fused_batch_norm_grad,
mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad),
CopyAttrsFusedBatchNorm, AlwaysRewrite});
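// Rewrite _FusedConv2D into _MklFusedConv2D only for fusion patterns that
// FusedConv2DRewrite accepts; attributes are copied by CopyAttrsFusedConv2D.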
rinfo_.push_back({csinfo_.fused_conv2d, csinfo_.mkl_fused_conv2d,
CopyAttrsFusedConv2D, FusedConv2DRewrite});
rinfo_.push_back({csinfo_.identity,
mkl_op_registry::GetMklOpName(csinfo_.identity),
CopyAttrsDataType, AlwaysRewrite});
......@@ -665,6 +669,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
string conv3d_grad_filter;
string fused_batch_norm;
string fused_batch_norm_grad;
string fused_conv2d;
string identity;
string lrn;
string lrn_grad;
......@@ -679,6 +684,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
string mkl_conv2d_grad_filter;
string mkl_conv2d_grad_filter_with_bias;
string mkl_conv2d_with_bias;
string mkl_fused_conv2d;
string mkl_pad_with_conv2d;
string mul;
string pad;
......@@ -1174,6 +1180,23 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
return false;
}
static bool FusedConv2DRewrite(const Node* n) {
// MKL DNN currently doesn't support all fusions that grappler fuses
// together with Conv2D (e.g., batchnorm). We rewrite _FusedConv2D only if
// its fusion list is one we support.
DataType T;
if (!GetNodeAttr(n->def(), "T", &T).ok() ||
!mkl_op_registry::IsMklOp(csinfo_.mkl_fused_conv2d, T)) {
return false;
}
std::vector<string> fused_ops;
TF_CHECK_OK(GetNodeAttr(n->def(), "fused_ops", &fused_ops));
return (fused_ops == std::vector<string>{"BiasAdd"} ||
fused_ops == std::vector<string>{"Relu"} ||
fused_ops == std::vector<string>{"BiasAdd", "Relu"});
}
// Rewrites input node to a new node specified by its matching rewrite info.
//
// Method first searches matching rewrite info for input node and then
......@@ -1335,6 +1358,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
bool change_format = false);
static void CopyAttrsFusedBatchNorm(const Node* orig_node, NodeBuilder* nb,
bool change_format = false);
static void CopyAttrsFusedConv2D(const Node* orig_node, NodeBuilder* nb,
bool change_format = false);
static void CopyAttrsLRN(const Node* orig_node, NodeBuilder* nb,
bool change_format = false);
static void CopyAttrsPadWithConv2D(const Node* orig_node, NodeBuilder* nb,
......@@ -1554,12 +1579,13 @@ int MklLayoutRewritePass::SetUpContiguousInputs(
CHECK_NOTNULL(filter_node);
// Now check which nodes receive from filter_node. Filter feeds as
// 2nd input (slot 1) of _MklConv2D and _MklConv2DWithBias.
// 2nd input (slot 1) of _MklConv2D, _MklConv2DWithBias, and
// _MklFusedConv2D.
for (const Edge* e : filter_node->out_edges()) {
if ((e->dst()->type_string() == csinfo_.mkl_conv2d ||
// add check for mkl_pad_with_conv2d
e->dst()->type_string() == csinfo_.mkl_pad_with_conv2d ||
e->dst()->type_string() == csinfo_.mkl_conv2d_with_bias) &&
e->dst()->type_string() == csinfo_.mkl_conv2d_with_bias ||
e->dst()->type_string() == csinfo_.mkl_fused_conv2d) &&
e->dst_input() == kConv2DFilterInputSlotIdx
/* filter is 2nd input of Conv2D and _MklConv2D. */) {
if (conv2d_node != nullptr) {
......@@ -2234,6 +2260,39 @@ void MklLayoutRewritePass::CopyAttrsFusedBatchNorm(const Node* orig_node,
nb->Attr("is_training", is_training);
}
void MklLayoutRewritePass::CopyAttrsFusedConv2D(const Node* orig_node,
NodeBuilder* nb,
bool change_format) {
DataType T;
int num_args;
float epsilon;
string data_format;
string padding;
std::vector<int32> strides;
std::vector<int32> dilations;
std::vector<string> fused_ops;
// Get all attributes from old node.
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "num_args", &num_args));
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations));
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "fused_ops", &fused_ops));
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "epsilon", &epsilon));
// Add attributes to new node.
nb->Attr("T", T);
nb->Attr("num_args", num_args);
nb->Attr("strides", strides);
nb->Attr("padding", padding);
nb->Attr("data_format", data_format);
nb->Attr("dilations", dilations);
nb->Attr("fused_ops", fused_ops);
nb->Attr("epsilon", epsilon);
}
//////////////////////////////////////////////////////////////////////////
// Helper functions related to node merge pass
//////////////////////////////////////////////////////////////////////////
......@@ -2881,6 +2940,7 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
if (n->type_string() != csinfo_.conv2d_with_bias &&
n->type_string() != csinfo_.pad_with_conv2d &&
n->type_string() != csinfo_.conv2d_grad_filter_with_bias &&
n->type_string() != csinfo_.fused_conv2d &&
!mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(n->type_string()),
T)) {
return nullptr;
......
......@@ -133,6 +133,7 @@ REGISTER_OP("Input").Output("o: float").SetIsStateful();
REGISTER_OP("InputList").Output("o: N * float").Attr("N: int").SetIsStateful();
REGISTER_OP("HalfInput").Output("o: half").SetIsStateful();
REGISTER_OP("Int32Input").Output("o: int32").SetIsStateful();
REGISTER_OP("DoubleInput").Output("o: double").SetIsStateful();
REGISTER_OP("_MklInput").Output("o: uint8").SetIsStateful();
REGISTER_OP("_MklInput2")
.Output("o: uint8")
......@@ -142,7 +143,7 @@ REGISTER_OP("Output2").Input("i: float").Input("i1: float").SetIsStateful();
REGISTER_OP("Output").Input("i: float").SetIsStateful();
/////////////////////////////////////////////////////////////////////
// Unit tests related to node merge optiimization
// Unit tests related to node merge optimization
/////////////////////////////////////////////////////////////////////
TEST_F(MklLayoutPassTest, Basic) {
......@@ -1096,6 +1097,131 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Negative_UnsupportedType) {
"A->C;B->C:1;B->D;C->D:1");
}
// Rewrite test for _FusedConv2D Op with BiasAdd fusion
TEST_F(MklLayoutPassTest, NodeRewrite_FusedConv2D_Positive1) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: '_FusedConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'num_args' value { i: 1 } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'fused_ops' value { list: {s: 'BiasAdd'} } }"
" attr { key: 'epsilon' value { f: 0.001 }}"
" input: ['A', 'B', 'C']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['D', 'C'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(_MklFusedConv2D);DMT/_0(Const);"
"DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;"
"A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;B->D:1;C->D:2;C->E:1;D->E;"
"DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
}
// Rewrite test for _FusedConv2D Op with Relu fusion
TEST_F(MklLayoutPassTest, NodeRewrite_FusedConv2D_Positive2) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: '_FusedConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'num_args' value { i: 1 } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'fused_ops' value { list: {s: 'Relu'} } }"
" attr { key: 'epsilon' value { f: 0.001 }}"
" input: ['A', 'B', 'C']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['D', 'C'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(_MklFusedConv2D);DMT/_0(Const);"
"DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;"
"A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;B->D:1;C->D:2;C->E:1;D->E;"
"DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
}
// Rewrite test for _FusedConv2D Op with BiasAdd+Relu fusion
TEST_F(MklLayoutPassTest, NodeRewrite_FusedConv2D_Positive3) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: '_FusedConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'num_args' value { i: 1 } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'fused_ops'"
" value { list: {s: 'BiasAdd', s: 'Relu'} } }"
" attr { key: 'epsilon' value { f: 0.001 }}"
" input: ['A', 'B', 'C']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['D', 'C'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(_MklFusedConv2D);DMT/_0(Const);"
"DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;"
"A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;B->D:1;C->D:2;C->E:1;D->E;"
"DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
}
// Rewrite test for _FusedConv2D Op with unsupported fusion
TEST_F(MklLayoutPassTest, NodeRewrite_FusedConv2D_Negative1) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'D' op: '_FusedConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'num_args' value { i: 1 } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'fused_ops' value { list: {s: 'Unsupported'} } }"
" attr { key: 'epsilon' value { f: 0.001 }}"
" input: ['A', 'B', 'C']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['D', 'C'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(_FusedConv2D);E(Zeta)|A->D;"
"B->D:1;C->D:2;C->E:1;D->E");
}
// Rewrite test for _FusedConv2D Op with unsupported type
TEST_F(MklLayoutPassTest, NodeRewrite_FusedConv2D_Negative2) {
InitGraph(
"node { name: 'A' op: 'DoubleInput'}"
"node { name: 'B' op: 'DoubleInput'}"
"node { name: 'C' op: 'DoubleInput'}"
"node { name: 'D' op: '_FusedConv2D'"
" attr { key: 'T' value { type: DT_DOUBLE } }"
" attr { key: 'num_args' value { i: 1 } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'fused_ops' value { list: {s: 'BiasAdd'} } }"
" attr { key: 'epsilon' value { f: 0.001 }}"
" input: ['A', 'B', 'C']}"
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_DOUBLE } }"
" input: ['D', 'C'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(DoubleInput);B(DoubleInput);C(DoubleInput);"
"D(_FusedConv2D);E(Zeta)|A->D;B->D:1;C->D:2;C->E:1;D->E");
}
TEST_F(MklLayoutPassTest, NodeRewrite_Conv2DGradFilter_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
......
......@@ -1022,7 +1022,7 @@ class MklConvOp : public OpKernel {
// get a conv2d fwd from primitive pool
MklConvFwdPrimitive<float, Tinput, Tfilter, Tbias, Ttemp_output>*
conv_fwd = nullptr;
if (biasEnabled) {
if (fuse_biasadd_) {
memory::dims bias_dims = {};
conv_utl.GetBiasSizeInMklOrder(kInputIndex_Bias, &bias_dims);
MklConvFwdParams convFwdDims(src_dims, filter_dims, bias_dims,
......@@ -1094,7 +1094,7 @@ class MklConvOp : public OpKernel {
}
// execute convolution
if (biasEnabled) {
if (fuse_biasadd_) {
const Tensor& bias_tensor = MklGetInput(context, kInputIndex_Bias);
Tbias* bias_data =
this->GetBiasHandle(context, conv_fwd_pd, bias_tensor);
......@@ -1154,6 +1154,12 @@ class MklConvOp : public OpKernel {
}
protected:
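// Setters used by derived classes (e.g., MklFusedConvOp below) to enable
// fusions discovered from the 'fused_ops' attribute.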
void set_fuse_biasadd(bool fuse_biasadd) { fuse_biasadd_ = fuse_biasadd; }
void set_fuse_relu(bool fuse_relu) { fuse_relu_ = fuse_relu; }
// This method is for the base class MklConvOp, which handles the
// floating-point implementation of Conv. The quantized conv implementations
// will use overridden versions of this method.
virtual void ExtendConvFwdParams(OpKernelContext* context,
MklConvFwdParams& params) {
// Create a string from data types of input, filter, bias, and output.
......@@ -1161,6 +1167,11 @@ class MklConvOp : public OpKernel {
params.dtypes.append(typeid(Tfilter).name());
params.dtypes.append(typeid(Tbias).name());
params.dtypes.append(typeid(Toutput).name());
// Add fusions as post ops
// Note: Fusion of BiasAdd is handled directly inside MklConvOp by
// checking the fuse_biasadd_ flag.
if (fuse_relu_) params.post_op_params.push_back({"relu", {1.0, 0.0, 0.0}});
}
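// Illustration only (a sketch, not part of this change): a post_op_param
// such as {"relu", {1.0, 0.0, 0.0}} is expected to be turned into an
// MKL-DNN eltwise post-op when the convolution primitive is created,
// roughly along these lines (local variable names are assumptions):
//
//   mkldnn::post_ops post_ops;
//   post_ops.append_eltwise(/*scale=*/1.0f, mkldnn::algorithm::eltwise_relu,
//                           /*alpha=*/0.0f, /*beta=*/0.0f);
//   mkldnn::primitive_attr conv_attr;
//   conv_attr.set_post_ops(post_ops);
//   // conv_attr is then supplied when constructing
//   // mkldnn::convolution_forward::primitive_desc.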
virtual Tbias* GetBiasHandle(
......@@ -1168,7 +1179,7 @@ class MklConvOp : public OpKernel {
std::shared_ptr<mkldnn::convolution_forward::primitive_desc>&
conv2d_fwd_pd,
const Tensor& bias_tensor) {
if (biasEnabled) {
if (fuse_biasadd_) {
return static_cast<Tbias*>(
const_cast<Tbias*>(bias_tensor.flat<Tbias>().data()));
} else {
......@@ -1214,6 +1225,11 @@ class MklConvOp : public OpKernel {
std::vector<int32> dilations_;
Padding padding_;
TensorFormat data_format_;
// Initialize to values the template is instantiated with
bool fuse_biasadd_ = biasEnabled;
bool fuse_relu_ = false;
const int kInputIndex_Src = 0, kInputIndex_Filter = 1, kInputIndex_Bias = 2;
const int kInputIndex_Pad = 2;
const int kOutputIndex_Dst = 0, kOutputIndex_Filter = 1;
......@@ -1267,12 +1283,12 @@ class MklConvOp : public OpKernel {
// Create convolution primitive and add it to net.
std::vector<primitive> net;
if (bias) {
DCHECK(biasEnabled);
DCHECK(fuse_biasadd_);
net.push_back(convolution_forward(conv_prim_desc, src->GetOpMem(),
filter->GetOpMem(), bias->GetOpMem(),
output->GetOpMem()));
} else {
DCHECK(!biasEnabled);
DCHECK(!fuse_biasadd_);
net.push_back(convolution_forward(conv_prim_desc, src->GetOpMem(),
filter->GetOpMem(),
output->GetOpMem()));
......@@ -1282,6 +1298,49 @@ class MklConvOp : public OpKernel {
}
};
// Base class for fused convolution forward operations
template <typename Device, typename Tinput, typename Tfilter, typename Tbias,
typename Toutput, typename Ttemp_output>
class MklFusedConvOp : public MklConvOp<Device, Tinput, Tfilter, Tbias, Toutput,
Ttemp_output, int32, false, false> {
public:
explicit MklFusedConvOp(OpKernelConstruction* context)
: MklConvOp<Device, Tinput, Tfilter, Tbias, Toutput, Ttemp_output, int32,
false, false>(context) {
// Since we came here through the registration of _MklFusedConv2D, get
// all fusion information from 'fused_ops' and 'num_args'.
std::vector<string> fused_ops;
OP_REQUIRES_OK(context, context->GetAttr("fused_ops", &fused_ops));
int num_args;
OP_REQUIRES_OK(context, context->GetAttr("num_args", &num_args));
OP_REQUIRES(context, !fused_ops.empty(),
errors::InvalidArgument(
"Fused Conv2D must have at least one fused op."));
if (fused_ops == std::vector<string>{"BiasAdd"}) {
this->set_fuse_biasadd(true);
OP_REQUIRES(context, num_args == 1,
errors::InvalidArgument(
"Fused Conv2D must have one extra argument: bias."));
} else if (fused_ops == std::vector<string>{"Relu"}) {
this->set_fuse_relu(true);
} else if (fused_ops == std::vector<string>{"BiasAdd", "Relu"}) {
this->set_fuse_biasadd(true);
this->set_fuse_relu(true);
OP_REQUIRES(context, num_args == 1,
errors::InvalidArgument(
"Fused Conv2D must have one extra argument: bias."));
} else {
OP_REQUIRES(context, false,
errors::Unimplemented("Fusion is not implemented: [",
str_util::Join(fused_ops, ","), "]"));
}
}
virtual ~MklFusedConvOp() {}
};
// We create a new class for each version of Quantized Convolution and inherit
// from the FP32 version of the base class.
template <typename Device, typename Tbias, typename Toutput,
......@@ -1881,6 +1940,16 @@ REGISTER_KERNEL_BUILDER(
TF_CALL_float(REGISTER_MKL_CPU_2D);
#define REGISTER_MKL_CPU_2D_FUSED(T) \
REGISTER_KERNEL_BUILDER(Name("_MklFusedConv2D") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklOpLabel), \
MklFusedConvOp<CPUDevice, T, T, T, T, T>);
// We check the 'fused_ops' attribute to decide whether bias is enabled.
TF_CALL_float(REGISTER_MKL_CPU_2D_FUSED);
// Register 3D operations
#define REGISTER_MKL_CPU_3D(T) \
REGISTER_KERNEL_BUILDER( \
......
......@@ -32,17 +32,17 @@ limitations under the License.
namespace tensorflow {
// Helper class for converting MKL tesnors to TF tensors and comparing to
// Helper class for converting MKL tensors to TF tensors and comparing to
// expected values
static const uint8 dummy_tensor[] = {0, 0, 0, 0, 0, 0, 0, 0};
static const TensorShape dummy_shape({8});
template <typename T>
class ConvMklToTF : public OpsTestBase {
public:
template <typename T>
void ConvertAndCompare(DataType dtype, const Tensor& first,
const Tensor& second, const Tensor& expected) {
void PerformConversion(DataType dtype, const Tensor& tensor,
const Tensor& mkl_meta_tensor, Tensor* output) {
// Create an MKL to TF conversion node and execute it
TF_EXPECT_OK(NodeDefBuilder("mkl_to_tf_op", "_MklToTf")
.Input(FakeInput(dtype)) // Input
......@@ -51,16 +51,259 @@ class ConvMklToTF : public OpsTestBase {
.Attr("_kernel", "MklOp")
.Finalize(node_def()));
TF_EXPECT_OK(InitOp());
AddInputFromArray<T>(first.shape(), first.flat<T>());
AddInputFromArray<uint8>(second.shape(), second.flat<uint8>());
AddInputFromArray<T>(tensor.shape(), tensor.flat<T>());
AddInputFromArray<uint8>(mkl_meta_tensor.shape(),
mkl_meta_tensor.flat<uint8>());
TF_ASSERT_OK(RunOpKernel());
const Tensor& output = *GetOutput(0);
*output = *GetOutput(0);
}
void ConvertAndCompare(DataType dtype, const Tensor& tensor,
const Tensor& mkl_meta_tensor,
const Tensor& expected) {
Tensor output;
PerformConversion(dtype, tensor, mkl_meta_tensor, &output);
test::ExpectTensorNear<T>(expected, output, 1e-5);
}
void TestBody(){};
void TestBody() {}
};
// Testing MKL's fused convolution ops
template <typename T>
class MklFusedConv2DOpTest : public OpsTestBase {
protected:
static constexpr int kDepth = 3;
static constexpr int kImageWidth = 32;
static constexpr int kImageHeight = 32;
static constexpr int kImageBatchCount = 8;
using BiasAddGraphRunner =
std::function<void(const Tensor& input_data, const Tensor& filter_data,
const Tensor& bias_data, Tensor* out)>;
// Runs a TensorFlow graph defined by the root scope and fetches the result
// of the 'fetch' node into the output Tensor.
void RunAndFetch(const tensorflow::Scope& root, const string& fetch,
Tensor* output) {
tensorflow::GraphDef graph;
TF_ASSERT_OK(root.ToGraphDef(&graph));
std::unique_ptr<tensorflow::Session> session(
tensorflow::NewSession(tensorflow::SessionOptions()));
TF_ASSERT_OK(session->Create(graph));
std::vector<Tensor> unfused_tensors;
TF_ASSERT_OK(session->Run({}, {fetch}, {}, &unfused_tensors));
*output = unfused_tensors[0];
}
void RunConv2DWithBias(const Tensor& input_data, const Tensor& filter_data,
const Tensor& bias_data, Tensor* output,
int stride = 1) {
auto root = tensorflow::Scope::NewRootScope();
auto conv = ops::Conv2D(
root.WithOpName("conv"),
ops::Const(root.WithOpName("input"), Input::Initializer(input_data)),
ops::Const(root.WithOpName("filter"), Input::Initializer(filter_data)),
{1, stride, stride, 1}, "SAME");
auto with_bias = ops::BiasAdd(
root.WithOpName("with_bias"), conv,
ops::Const(root.WithOpName("bias"), Input::Initializer(bias_data)));
RunAndFetch(root, "with_bias", output);
}
void RunConv2DWithBiasAndRelu(const Tensor& input_data,
const Tensor& filter_data,
const Tensor& bias_data, Tensor* output,
int stride = 1) {
auto root = tensorflow::Scope::NewRootScope();
auto conv = ops::Conv2D(
root.WithOpName("conv"),
ops::Const(root.WithOpName("input"), Input::Initializer(input_data)),
ops::Const(root.WithOpName("filter"), Input::Initializer(filter_data)),
{1, stride, stride, 1}, "SAME");
auto with_bias = ops::BiasAdd(
root.WithOpName("with_bias"), conv,
ops::Const(root.WithOpName("bias"), Input::Initializer(bias_data)));
auto with_relu = ops::Relu(root.WithOpName("with_relu"), with_bias);
RunAndFetch(root, "with_relu", output);
}
void RunMklFusedConv2DOp(const Tensor& image, const Tensor& filter,
const std::vector<Tensor>& args,
const std::vector<string>& fused_ops, Tensor* output,
int stride = 1) {
DataType dtype = DataTypeToEnum<T>::v();
int num_args = static_cast<int>(args.size());
TF_EXPECT_OK(NodeDefBuilder("fused_conv_op", "_MklFusedConv2D")
.Input(FakeInput(dtype))
.Input(FakeInput(dtype))
.Attr("num_args", num_args)
.Input(FakeInput(num_args, dtype))
.Input(FakeInput(DT_UINT8))
.Input(FakeInput(DT_UINT8))
.Input(FakeInput(num_args, DT_UINT8))
.Attr("T", dtype)
.Attr("strides", {1, stride, stride, 1})
.Attr("padding", "SAME")
.Attr("fused_ops", fused_ops)
.Attr("_kernel", "MklOp")
.Finalize(node_def()));
TF_EXPECT_OK(InitOp());
AddInputFromArray<T>(image.shape(), image.flat<T>());
AddInputFromArray<T>(filter.shape(), filter.flat<T>());
for (const Tensor& arg : args)
AddInputFromArray<T>(arg.shape(), arg.flat<T>());
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
for (const Tensor& arg : args)
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
TF_ASSERT_OK(RunOpKernel());
// Convert the output back to TF format; callers compare it to expected results.
const Tensor& output_tensor = *GetOutput(0);
// Index 2 will need to be changed if the number of outputs produced
// by MklConv2D changes.
const Tensor& output_meta_tensor = *GetOutput(2);
ConvMklToTF<T> conv_comp;
conv_comp.PerformConversion(dtype, output_tensor, output_meta_tensor,
output);
}
void VerifyBiasAddTensorsNear(int depth, int image_width, int image_height,
int image_batch_count, int filter_size,
int filter_count,
const BiasAddGraphRunner& run_default,
const BiasAddGraphRunner& run_fused) {
DataType dtype = DataTypeToEnum<T>::v();
Tensor image(dtype, {image_batch_count, image_height, image_width, depth});
image.flat<T>() = image.flat<T>().setRandom();
Tensor filter(dtype, {filter_size, filter_size, depth, filter_count});
filter.flat<T>() = filter.flat<T>().setRandom();
const int bias_size = filter_count;
Tensor bias(dtype, {bias_size});
bias.flat<T>() = bias.flat<T>().setRandom();
Tensor conv_2d;
Tensor fused_conv_2d;
run_default(image, filter, bias, &conv_2d);
run_fused(image, filter, bias, &fused_conv_2d);
ASSERT_EQ(conv_2d.dtype(), fused_conv_2d.dtype());
ASSERT_EQ(conv_2d.shape(), fused_conv_2d.shape());
test::ExpectClose(conv_2d, fused_conv_2d);
}
// Verifies that computing Conv2D+BiasAdd in a graph is identical to
// FusedConv2D.
void VerifyConv2DWithBias(int filter_size, int filter_count,
int depth = kDepth, int image_width = kImageWidth,
int image_height = kImageHeight,
int image_batch_count = kImageBatchCount) {
const BiasAddGraphRunner run_default =
[this](const Tensor& input_data, const Tensor& filter_data,
const Tensor& bias_data, Tensor* out) {
RunConv2DWithBias(input_data, filter_data, bias_data, out);
};
const BiasAddGraphRunner run_fused =
[this](const Tensor& input_data, const Tensor& filter_data,
const Tensor& bias_data, Tensor* out) {
RunMklFusedConv2DOp(input_data, filter_data, {bias_data}, {"BiasAdd"},
out);
};
VerifyBiasAddTensorsNear(depth, image_width, image_height,
image_batch_count, filter_size, filter_count,
run_default, run_fused);
}
// Verifies that computing Conv2D+BiasAdd+Relu in a graph is identical to
// FusedConv2D.
void VerifyConv2DWithBiasAndRelu(int filter_size, int filter_count,
int depth = kDepth,
int image_width = kImageWidth,
int image_height = kImageHeight,
int image_batch_count = kImageBatchCount) {
const BiasAddGraphRunner run_default =
[this](const Tensor& input_data, const Tensor& filter_data,
const Tensor& bias_data, Tensor* out) {
RunConv2DWithBiasAndRelu(input_data, filter_data, bias_data, out);
};
const BiasAddGraphRunner run_fused =
[this](const Tensor& input_data, const Tensor& filter_data,
const Tensor& bias_data, Tensor* out) {
RunMklFusedConv2DOp(input_data, filter_data, {bias_data},
{"BiasAdd", "Relu"}, out);
};
VerifyBiasAddTensorsNear(depth, image_width, image_height,
image_batch_count, filter_size, filter_count,
run_default, run_fused);
}
};
template <typename T>
class MklFusedConv2DWithBiasOpTest : public MklFusedConv2DOpTest<T> {};
TYPED_TEST_CASE_P(MklFusedConv2DWithBiasOpTest);
// -------------------------------------------------------------------------- //
// Conv2D + BiasAdd + {Relu} //
// -------------------------------------------------------------------------- //
TYPED_TEST_P(MklFusedConv2DWithBiasOpTest, OneByOneConvolution) {
const int filter_size = 1;
const int filter_count = 12;
this->VerifyConv2DWithBias(filter_size, filter_count);
}
TYPED_TEST_P(MklFusedConv2DWithBiasOpTest, SpatialConvolution) {
const int filter_size = 3;
const int filter_count = 12;
this->VerifyConv2DWithBias(filter_size, filter_count);
}
TYPED_TEST_P(MklFusedConv2DWithBiasOpTest, OneByOneConvolutionAndRelu) {
const int filter_size = 1;
const int filter_count = 12;
this->VerifyConv2DWithBiasAndRelu(filter_size, filter_count);
}
TYPED_TEST_P(MklFusedConv2DWithBiasOpTest, SpatialConvolutionAndRelu) {
const int filter_size = 3;
const int filter_count = 12;
this->VerifyConv2DWithBiasAndRelu(filter_size, filter_count);
}
REGISTER_TYPED_TEST_CASE_P(MklFusedConv2DWithBiasOpTest, //
OneByOneConvolution, //
SpatialConvolution, //
OneByOneConvolutionAndRelu, //
SpatialConvolutionAndRelu);
using MklFusedBiasAddDataTypes = ::testing::Types<float>;
INSTANTIATE_TYPED_TEST_CASE_P(Test, MklFusedConv2DWithBiasOpTest,
MklFusedBiasAddDataTypes);
// Testing fusion of pad and convolution
class FusedPadConvOpTest : public OpsTestBase {
......@@ -98,8 +341,8 @@ class FusedPadConvOpTest : public OpsTestBase {
// Compare output to expected results
const Tensor& first = *GetOutput(0);
const Tensor& second = *GetOutput(2);
ConvMklToTF conv_comp;
conv_comp.ConvertAndCompare<T>(dtype, first, second, expected);
ConvMklToTF<T> conv_comp;
conv_comp.ConvertAndCompare(dtype, first, second, expected);
}
};
......
......@@ -32,6 +32,33 @@ using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
REGISTER_OP("_MklFusedConv2D")
.Input("input: T")
.Input("filter: T")
.Input("args: num_args * T")
.Input("mkl_input: uint8")
.Input("mkl_filter: uint8")
.Input("mkl_args: num_args * uint8")
.Output("output: T")
.Output("filter_output: T")
.Output("mkl_output: uint8")
.Output("mkl_filter_output: uint8")
.Attr("T: {float}")
.Attr("num_args: int >= 0")
.Attr("strides: list(int)")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
.Attr("dilations: list(int) = [1, 1, 1, 1]")
.Attr("fused_ops: list(string) = []")
// Attributes for the FusedBatchNorm ------------------------------------ //
.Attr("epsilon: float = 0.0001")
// ---------------------------------------------------------------------- //
.SetShapeFn(shape_inference::Conv2DShape)
.Doc(R"doc(
*NOTE*: Do not invoke this operator directly in Python. The MKL DNN graph
transformer is expected to create these operators.
)doc");
REGISTER_OP("_MklQuantizedMaxPool")
.Input("input: T")
.Input("min_input: float")
......