Commit cd1ba129 authored by Bhavani Subramanian

Code cleanup

Parent ad2702d2
@@ -1672,24 +1672,23 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     DCHECK(n);
     Node* filter_node = nullptr;
     TF_CHECK_OK(n->input_node(0, &filter_node));
-    bool narrow_range = false;
-    int axis = -1;
-    string mode_string;
-    string round_mode_string;
-    DataType type;
-    TryGetNodeAttr(n->def(), "narrow_range", &narrow_range);
-    TryGetNodeAttr(n->def(), "axis", &axis);
-    TF_CHECK_OK(GetNodeAttr(n->def(), "mode", &mode_string));
-    TF_CHECK_OK(GetNodeAttr(n->def(), "round_mode", &round_mode_string));
-    TF_CHECK_OK(GetNodeAttr(n->def(), "T", &type));
-    if (narrow_range) {
+    bool narrow_range;
+    if (TryGetNodeAttr(n->def(), "narrow_range", &narrow_range) &&
+        narrow_range) {
       VLOG(1) << "QuantizeOpRewrite: narrow range is enabled for quantization."
               << "This case is not optimized by Intel MKL, "
               << "thus using Eigen op for Quantize op ";
       return false;
     }
-    if (axis != -1) {
+    int axis;
+    if (TryGetNodeAttr(n->def(), "axis", &axis) && axis != -1) {
       VLOG(1) << "QuantizeOpRewrite: dimension is specified for "
               << "per slice quantization."
               << "This case is not optimized by Intel MKL, "
......
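Besides guarding the reads, this hunk drops the mode, round_mode, and T lookups, whose values were never used. The guarded form acts only when the attribute is both present and set, since TryGetNodeAttr returns false for a missing attribute and && short-circuits. A minimal sketch of the pattern (hypothetical attribute name and helper, not code from this pass):

    // Returns false when the optional attribute is present and enabled.
    // "my_flag" is an illustrative attribute name.
    bool RewriteAllowed(const Node* n) {
      bool my_flag;
      if (TryGetNodeAttr(n->def(), "my_flag", &my_flag) && my_flag) {
        return false;  // present and set: skip the rewrite
      }
      return true;  // absent or unset: rewrite may proceed
    }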
@@ -70,10 +70,6 @@ struct MklBatchMatMulHelper {
     if (ndims_rhs < ndims_out) {
       ExpandInputDimsToOutputShape(rhs_shape, out_shape, &rhs_dims);
     }
-    using dim = dnnl::memory::dim;
-    dim m;  // Number of rows in x
-    dim k;  // Number of columns in x
-    dim n;  // Number of columns in y
     auto lhs_strides = CalculateTFStrides(lhs_dims);
     auto rhs_strides = CalculateTFStrides(rhs_dims);
     auto out_strides = CalculateTFStrides(out_dims);
@@ -81,8 +77,7 @@ struct MklBatchMatMulHelper {
     if (adj_x) {
       int m_idx = ndims_out - 1;
       int k_idx = ndims_out - 2;
-      m = lhs_dims[m_idx];
-      k = lhs_dims[k_idx];
+      memory::dim m = lhs_dims[m_idx];  // number of rows in x
       std::swap(lhs_dims[m_idx], lhs_dims[k_idx]);
       lhs_strides[m_idx] = m;
       lhs_strides[k_idx] = 1;
@@ -91,8 +86,7 @@ struct MklBatchMatMulHelper {
     if (adj_y) {
       int k_idx = ndims_out - 1;
       int n_idx = ndims_out - 2;
-      k = rhs_dims[k_idx];
-      n = rhs_dims[n_idx];
+      memory::dim k = rhs_dims[k_idx];  // number of columns in x
       std::swap(rhs_dims[k_idx], rhs_dims[n_idx]);
       rhs_strides[k_idx] = k;
       rhs_strides[n_idx] = 1;
......
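These three hunks narrow the scope of the temporaries: m and k are now declared inside the only branches that read them, and the unused n disappears along with the function-scope declarations and the dim alias. A standalone sketch of the same swap-and-restride step (illustrative helper, std::int64_t standing in for memory::dim):

    #include <cstddef>
    #include <cstdint>
    #include <utility>
    #include <vector>

    // Swap the last two dims and set strides so the buffer is read transposed.
    void AdjustForAdjoint(std::vector<std::int64_t>& dims,
                          std::vector<std::int64_t>& strides) {
      const std::size_t last = dims.size() - 1;
      std::int64_t m = dims[last];  // rows before the swap, as in the diff
      std::swap(dims[last], dims[last - 1]);
      strides[last] = m;
      strides[last - 1] = 1;
    }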
@@ -49,7 +49,9 @@ template <typename Device, typename Tlhs, typename Trhs, typename Toutput,
 class BatchMatMulMkl : public OpKernel {
  public:
   explicit BatchMatMulMkl(OpKernelConstruction* context) : OpKernel(context) {
-    if (context && context->HasAttr("transpose_a")) {
+    if (!context) return;
+
+    if (context->HasAttr("transpose_a")) {
       // This is needed for using BatchMatMulMkl as the super class of
       // MklMatMulOp (below) whose context has a transpose_a attribute which is
       // effectively the same as adj_x_
@@ -58,7 +60,7 @@ class BatchMatMulMkl : public OpKernel {
       OP_REQUIRES_OK(context, context->GetAttr("adj_x", &adj_x_));
     }

-    if (context && context->HasAttr("transpose_b")) {
+    if (context->HasAttr("transpose_b")) {
       // This is needed for using BatchMatMulMkl as the super class of
       // MklMatMulOp (below) whose context has a transpose_b attribute which is
       // effectively the same as adj_y_
@@ -294,6 +296,10 @@ class FusedBatchMatMulMkl
     }
     if (this->fused_ops_.size() > 1 && this->fused_ops_.at(1) == "Add") {
       auto add_shape = ctx->input(3).shape();
+      OP_REQUIRES(ctx, add_shape.dims() == 4,
+                  absl::InvalidArgumentError(absl::StrCat(
+                      "Add fusion expects add shape to have 4 dims, but got ",
+                      add_shape.dims())));
       memory::dims add_dims = {add_shape.dim_size(0), add_shape.dim_size(1),
                                add_shape.dim_size(2), add_shape.dim_size(3)};
       params.post_op_params.push_back(
......
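The added OP_REQUIRES turns a malformed Add operand into an InvalidArgument error before add_shape.dim_size(0..3) is indexed. The same validate-before-index pattern in plain Abseil, outside the kernel macros (illustrative function name):

    #include <vector>
    #include "absl/status/status.h"
    #include "absl/strings/str_cat.h"

    // Reject non-4D shapes before any per-dimension access happens.
    absl::Status CheckRank4(const std::vector<long long>& shape) {
      if (shape.size() != 4) {
        return absl::InvalidArgumentError(
            absl::StrCat("expected 4 dims, but got ", shape.size()));
      }
      return absl::OkStatus();
    }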
@@ -481,7 +481,7 @@ class MklConcatOp : public OpKernel {
   void Compute(OpKernelContext* context) override {
     try {
       auto cpu_engine = engine(engine::kind::cpu, 0);
-      OpInputList input_tensors;
+      OpInputList input_tensors(context, 0, 0);
       GetMklInputList(context, "values", &input_tensors);
       const int N = input_tensors.size();
       // Get Tensor shapes.
@@ -563,7 +563,8 @@ class MklConcatOp : public OpKernel {
       // That is due to an incorrect output results in DNNL 1.2 path.
       if (expected_dims == 2) invoke_eigen = true;

-      OpInputList input_mins, input_maxes;
+      OpInputList input_mins(context, 0, 0);
+      OpInputList input_maxes(context, 0, 0);
       bool quantized_input =
           std::is_same<T, qint8>::value || std::is_same<T, quint8>::value;
       if (quantized_input) {
......
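Both MklConcatOp hunks replace default-constructed OpInputList objects with the (context, start, stop) form, so every member is initialized even if the subsequent lookup fails. A sketch of the usage, assuming only OpInputList's public constructor and OpKernelContext::input_list as they appear in the diff ("values" is the variadic input name used here):

    #include "tensorflow/core/framework/op_kernel.h"

    // Fetch a variadic input list through a fully initialized OpInputList.
    void ReadValues(tensorflow::OpKernelContext* context) {
      tensorflow::OpInputList values(context, 0, 0);  // empty but well-defined
      OP_REQUIRES_OK(context, context->input_list("values", &values));
      for (int i = 0; i < values.size(); ++i) {
        const tensorflow::Tensor& t = values[i];  // valid once input_list() succeeds
        (void)t;  // use each input tensor here
      }
    }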
@@ -568,11 +568,17 @@ class MklDnnConvUtil {
       OP_REQUIRES(context_, input_tf_shape.dims() == 4,
                   errors::InvalidArgument("input must be 4-dimensional",
                                           input_tf_shape.DebugString()));
+      OP_REQUIRES(context_, filter_tf_shape.dims() == 4,
+                  errors::InvalidArgument("filter must be 4-dimensional",
+                                          filter_tf_shape.DebugString()));
     } else {
       // Conv3D
       OP_REQUIRES(context_, input_tf_shape.dims() == 5,
                   errors::InvalidArgument("input must be 5-dimensional",
                                           input_tf_shape.DebugString()));
+      OP_REQUIRES(context_, filter_tf_shape.dims() == 5,
+                  errors::InvalidArgument("filter must be 5-dimensional",
+                                          filter_tf_shape.DebugString()));
     }
     GetOutputAndPadSizeInMklOrder(input_tf_shape, filter_tf_shape, strides,
......
@@ -200,7 +200,7 @@ class MklEinsum : public OpKernel {
   virtual ~MklEinsum() {}

   void Compute(OpKernelContext* ctx) override {
-    OpInputList inputs;
+    OpInputList inputs(ctx, 0, 0);
     OP_REQUIRES_OK(ctx, ctx->input_list("inputs", &inputs));

     if (std::is_same<T, float>::value) {
......
@@ -651,7 +651,8 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive {
     std::vector<std::unordered_map<int, memory>> net_args;

     BatchNormBwdContext()
-        : src_mem(nullptr),
+        : flags(0),
+          src_mem(nullptr),
           mean_mem(nullptr),
           variance_mem(nullptr),
           diff_dst_mem(nullptr),
......
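flags is placed first in the initializer list because C++ runs member initializers in declaration order, and the list order suggests flags is declared before the memory handles in BatchNormBwdContext; initializing it closes a window where the field could be read while still indeterminate. The pattern in miniature (illustrative struct, not the TF context):

    #include <cstdint>

    struct BwdContext {
      std::int64_t flags;  // declared first, so initialized first
      void* src_mem;
      BwdContext() : flags(0), src_mem(nullptr) {}
    };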
@@ -131,7 +131,6 @@ class MklPoolingFwdPrimitive : public MklPrimitive {
     memory::format_tag ws_fmt;

     // Workspace shape.
     memory::dims ws_dims;
-
     memory::data_type ws_dt;
     size_t ws_size;
@@ -161,6 +160,8 @@ class MklPoolingFwdPrimitive : public MklPrimitive {
         : src_fmt(memory::format_tag::any),
           dst_fmt(memory::format_tag::any),
           ws_fmt(memory::format_tag::any),
+          ws_dt(memory::data_type::u8),
+          ws_size(0),
           ws_mem(nullptr),
           src_mem(nullptr),
           dst_mem(nullptr),
@@ -284,7 +285,6 @@ class MklPoolingBwdPrimitive : public MklPrimitive {
     memory::format_tag ws_fmt;

     // Workspace attribute.
     dnnl::memory::dims ws_dims;
-
     dnnl::memory::data_type ws_dt;
     // oneDNN memory.
@@ -315,6 +315,7 @@ class MklPoolingBwdPrimitive : public MklPrimitive {
         : diff_src_fmt(memory::format_tag::any),
           diff_dst_fmt(memory::format_tag::any),
           ws_fmt(memory::format_tag::any),
+          ws_dt(memory::data_type::u8),
           ws_mem(nullptr),
           diff_src_mem(nullptr),
           diff_dst_mem(nullptr),
......
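The pooling hunks do for PoolingFwdContext and PoolingBwdContext what the batch-norm hunk did above: the workspace metadata (ws_dt, plus ws_size in the forward context) now has explicit defaults, u8 being the type oneDNN typically uses for max-pooling workspaces. Reduced to its core (illustrative struct, assuming the oneDNN C++ header is available as in the MKL build):

    #include <cstddef>
    #include "dnnl.hpp"

    struct PoolingCtx {
      dnnl::memory::data_type ws_dt;
      std::size_t ws_size;
      // Defaults keep both fields well-defined when no workspace is created.
      PoolingCtx() : ws_dt(dnnl::memory::data_type::u8), ws_size(0) {}
    };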