提交 6e7b883b 编写于 作者: M mozga-intel

Initial implementation of multiplication operator for MKLDNN

上级 fee5b24c
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "mkldnn.hpp"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/mul_op.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
namespace paddle {
namespace operators {
using paddle::framework::Tensor;
using paddle::platform::MKLDNNDeviceContext;
template <typename Format = mkldnn::memory::format>
mkldnn::memory::desc type(const std::vector<int>& dims, Format&& f) {
return platform::MKLDNNMemDesc(dims, mkldnn::memory::data_type::f32, f);
}
template <typename T>
class MulMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
auto mkldnn_engine = dev_ctx.GetEngine();
auto input = ctx.Input<Tensor>("X");
auto weight = ctx.Input<Tensor>("Y");
PADDLE_ENFORCE(input->dims().size() & (2 | 4),
"Input must be with 2 or 4 dimensions, i.e. NC or NCHW");
PADDLE_ENFORCE(weight->dims().size() & (2 | 4),
"Weights must be with 2 or 4 dimensions, i.e. OI or OIHW");
std::vector<int> w_tz = paddle::framework::vectorize2int(weight->dims());
std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
auto src_md =
src_tz.size() != 2
? type(src_tz, mkldnn::memory::format::nchw)
: type({src_tz[0], src_tz[1]}, mkldnn::memory::format::nc);
auto dst_md = type({src_tz[0], w_tz[1]}, mkldnn::memory::format::nc);
auto weights_md =
src_tz.size() != 2
? type({w_tz[1], src_tz[1], src_tz[2], src_tz[3]},
mkldnn::memory::format::oihw)
: type({w_tz[1], src_tz[1]}, mkldnn::memory::format::oi);
auto output = ctx.Output<Tensor>("Out");
T* output_data = output->mutable_data<T>(ctx.GetPlace());
const std::string key = ctx.op().Output("Out");
const std::string key_fc_pd = key + "@mul_pd";
const T* input_data = input->data<T>();
const T* w_data = weight->data<T>();
auto dst_memory = mkldnn::memory({dst_md, mkldnn_engine}, output_data);
auto src_memory = mkldnn::memory({src_md, mkldnn_engine},
platform::to_void_cast(input_data));
auto weights_memory = mkldnn::memory({weights_md, mkldnn_engine},
platform::to_void_cast(w_data));
auto pd = platform::MKLDNNFwdPrimitiveDesc<mkldnn::inner_product_forward>(
mkldnn_engine, src_md, weights_md, dst_md);
dev_ctx.SetBlob(key_fc_pd, pd);
auto forward = mkldnn::inner_product_forward(*pd, src_memory,
weights_memory, dst_memory);
std::vector<mkldnn::primitive> pipeline = {forward};
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
}
};
template <typename T>
class MulMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
auto mkldnn_engine = dev_ctx.GetEngine();
const Tensor* input = ctx.Input<Tensor>("X");
const Tensor* w = ctx.Input<Tensor>("Y");
const Tensor* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
Tensor* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
Tensor* w_grad = ctx.Output<Tensor>(framework::GradVarName("Y"));
const std::string key = ctx.op().Input("Out");
const std::string key_fc_pd = key + "@mul_pd";
const T* input_data = input->data<T>();
const T* w_data = w->data<T>();
const T* out_grad_data = out_grad->data<T>();
T* input_grad_data = nullptr;
T* w_grad_data = nullptr;
if (input_grad) {
input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
}
if (w_grad) {
w_grad_data = w_grad->mutable_data<T>(ctx.GetPlace());
}
std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
std::vector<int> w_tz = paddle::framework::vectorize2int(w->dims());
auto src_md =
src_tz.size() != 2
? type(src_tz, mkldnn::memory::format::nchw)
: type({src_tz[0], src_tz[1]}, mkldnn::memory::format::nc);
auto dst_md = type({src_tz[0], w_tz[1]}, mkldnn::memory::format::nc);
auto weights_md =
src_tz.size() != 2
? type({w_tz[1], src_tz[1], src_tz[2], src_tz[3]},
mkldnn::memory::format::oihw)
: type({w_tz[1], src_tz[1]}, mkldnn::memory::format::oi);
auto src_memory = mkldnn::memory({src_md, mkldnn_engine},
platform::to_void_cast(input_data));
auto dst_memory = mkldnn::memory({dst_md, mkldnn_engine},
platform::to_void_cast(out_grad_data));
auto weight_memory = mkldnn::memory({weights_md, mkldnn_engine},
platform::to_void_cast(w_data));
auto pd =
std::static_pointer_cast<mkldnn::inner_product_forward::primitive_desc>(
dev_ctx.GetBlob(key_fc_pd));
PADDLE_ENFORCE(pd != nullptr, "Fail to find pd in device context");
if (w_grad) {
auto weights_grad_memory = mkldnn::memory(
{weights_md, mkldnn_engine}, platform::to_void_cast(w_grad_data));
auto bwd_weight_pd = platform::MKLDNNBwdPrimitiveDesc<
mkldnn::inner_product_backward_weights>(mkldnn_engine, *pd, src_md,
weights_md, dst_md);
auto bwd_weights_prim = mkldnn::inner_product_backward_weights(
bwd_weight_pd, src_memory, dst_memory, weights_grad_memory);
std::vector<mkldnn::primitive> pipeline{bwd_weights_prim};
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
}
if (input_grad) {
auto src_grad_memory = mkldnn::memory(
{src_md, mkldnn_engine}, platform::to_void_cast(input_grad_data));
auto bwd_data_pd =
platform::MKLDNNBwdPrimitiveDesc<mkldnn::inner_product_backward_data>(
mkldnn_engine, *pd, src_md, weights_md, dst_md);
auto bwd_data_prim = mkldnn::inner_product_backward_data(
bwd_data_pd, dst_memory, weight_memory, src_grad_memory);
std::vector<mkldnn::primitive> pipeline{bwd_data_prim};
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
}
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP_KERNEL(mul, MKLDNN, ::paddle::platform::CPUPlace,
paddle::operators::MulMKLDNNOpKernel<float>);
REGISTER_OP_KERNEL(mul_grad, MKLDNN, ::paddle::platform::CPUPlace,
paddle::operators::MulMKLDNNGradOpKernel<float>);
...@@ -13,8 +13,13 @@ See the License for the specific language governing permissions and ...@@ -13,8 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/mul_op.h" #include "paddle/fluid/operators/mul_op.h"
#include <string>
#include <vector> #include <vector>
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -71,6 +76,22 @@ class MulOp : public framework::OperatorWithKernel { ...@@ -71,6 +76,22 @@ class MulOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("Out", framework::make_ddim(output_dims)); ctx->SetOutputDim("Out", framework::make_ddim(output_dims));
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
} }
private:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
framework::LibraryType library{framework::LibraryType::kPlain};
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN;
}
#endif
framework::DataLayout layout{framework::DataLayout::kAnyLayout};
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
layout, library);
}
}; };
class MulOpMaker : public framework::OpProtoAndCheckerMaker { class MulOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -100,6 +121,9 @@ class MulOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -100,6 +121,9 @@ class MulOpMaker : public framework::OpProtoAndCheckerMaker {
)DOC") )DOC")
.SetDefault(1) .SetDefault(1)
.EqualGreaterThan(1); .EqualGreaterThan(1);
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<int>( AddAttr<int>(
"y_num_col_dims", "y_num_col_dims",
R"DOC((int, default 1), The mul_op can take tensors with more than two, R"DOC((int, default 1), The mul_op can take tensors with more than two,
...@@ -154,6 +178,22 @@ class MulGradOp : public framework::OperatorWithKernel { ...@@ -154,6 +178,22 @@ class MulGradOp : public framework::OperatorWithKernel {
ctx->SetOutputDim(y_grad_name, y_dims); ctx->SetOutputDim(y_grad_name, y_dims);
} }
} }
private:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
framework::LibraryType library{framework::LibraryType::kPlain};
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN;
}
#endif
framework::DataLayout layout{framework::DataLayout::kAnyLayout};
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
layout, library);
}
}; };
} // namespace operators } // namespace operators
......
...@@ -13,9 +13,8 @@ See the License for the specific language governing permissions and ...@@ -13,9 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <mkldnn.h>
#include <vector> #include <vector>
#include "mkldnn/include/mkldnn.hpp"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
namespace paddle { namespace paddle {
...@@ -34,6 +33,32 @@ typedef std::unique_ptr<MKLDNNMemory> MKLDNNMemoryPtr; ...@@ -34,6 +33,32 @@ typedef std::unique_ptr<MKLDNNMemory> MKLDNNMemoryPtr;
typedef std::unique_ptr<MKLDNNPrimitive> MKLDNNPrimitivePtr; typedef std::unique_ptr<MKLDNNPrimitive> MKLDNNPrimitivePtr;
typedef std::unique_ptr<MKLDNNPrimitiveDesc> MKLDNNPrimitiveDescPtr; typedef std::unique_ptr<MKLDNNPrimitiveDesc> MKLDNNPrimitiveDescPtr;
template <typename Type>
void* to_void_cast(const Type* t) {
return static_cast<void*>(const_cast<Type*>(t));
}
template <class Type>
using tf_desc = typename Type::desc;
template <class Type>
using tf_pd = typename Type::primitive_desc;
template <typename Type, typename Engine, typename... Args>
std::shared_ptr<tf_pd<Type>> MKLDNNFwdPrimitiveDesc(const Engine& e,
Args&&... args) {
auto desc = tf_desc<Type>(mkldnn::prop_kind::forward, (args)...);
auto pd = new tf_pd<Type>(desc, e);
return std::shared_ptr<tf_pd<Type>>(pd);
}
template <typename Type, typename Engine, typename Primitive, typename... Args>
tf_pd<Type> MKLDNNBwdPrimitiveDesc(const Engine& e, const Primitive& p,
Args&&... args) {
auto desc = tf_desc<Type>(args...);
return tf_pd<Type>(desc, e, p);
}
inline mkldnn::memory::desc MKLDNNMemDesc(const std::vector<int>& dims, inline mkldnn::memory::desc MKLDNNMemDesc(const std::vector<int>& dims,
mkldnn::memory::data_type data_type, mkldnn::memory::data_type data_type,
mkldnn::memory::format format) { mkldnn::memory::format format) {
......
...@@ -156,64 +156,37 @@ def fc(input, ...@@ -156,64 +156,37 @@ def fc(input,
dtype = helper.input_dtype() dtype = helper.input_dtype()
mul_results = [] mul_results = []
if use_mkldnn: for input_var, param_attr in helper.iter_inputs_and_params():
tmp = helper.create_tmp_variable(dtype) input_shape = input_var.shape
input_shape = input.shape
param_shape = [ param_shape = [
reduce(lambda a, b: a * b, input_shape[num_flatten_dims:], 1) reduce(lambda a, b: a * b, input_shape[num_flatten_dims:], 1)
] + [size] ] + [size]
w = helper.create_parameter( w = helper.create_parameter(
attr=helper.param_attr, attr=param_attr, shape=param_shape, dtype=dtype, is_bias=False)
shape=param_shape, tmp = helper.create_tmp_variable(dtype)
dtype=dtype,
is_bias=False)
if bias_attr is None or bias_attr is False:
bias_attr = False
else:
bias_attr = True
helper.append_op( helper.append_op(
type="fc", type="mul",
inputs={"Input": input, inputs={"X": input_var,
"W": w}, "Y": w},
outputs={"Out": tmp}, outputs={"Out": tmp},
attrs={"use_mkldnn": use_mkldnn, attrs={
"bias_attr": bias_attr}) "x_num_col_dims": num_flatten_dims,
return helper.append_activation(tmp) "y_num_col_dims": 1,
"use_mkldnn": use_mkldnn
})
mul_results.append(tmp)
if len(mul_results) == 1:
pre_bias = mul_results[0]
else: else:
for input_var, param_attr in helper.iter_inputs_and_params(): pre_bias = helper.create_tmp_variable(dtype)
input_shape = input_var.shape helper.append_op(
param_shape = [ type="sum", inputs={"X": mul_results}, outputs={"Out": pre_bias})
reduce(lambda a, b: a * b, input_shape[num_flatten_dims:], 1) # add bias
] + [size] pre_activation = helper.append_bias_op(pre_bias, dim_start=num_flatten_dims)
# add activation
w = helper.create_parameter( return helper.append_activation(pre_activation)
attr=param_attr, shape=param_shape, dtype=dtype, is_bias=False)
tmp = helper.create_tmp_variable(dtype)
helper.append_op(
type="mul",
inputs={"X": input_var,
"Y": w},
outputs={"Out": tmp},
attrs={
"x_num_col_dims": num_flatten_dims,
"y_num_col_dims": 1,
})
mul_results.append(tmp)
if len(mul_results) == 1:
pre_bias = mul_results[0]
else:
pre_bias = helper.create_tmp_variable(dtype)
helper.append_op(
type="sum",
inputs={"X": mul_results},
outputs={"Out": pre_bias})
# add bias
pre_activation = helper.append_bias_op(
pre_bias, dim_start=num_flatten_dims)
# add activation
return helper.append_activation(pre_activation)
def embedding(input, def embedding(input,
...@@ -3688,8 +3661,8 @@ def label_smooth(label, ...@@ -3688,8 +3661,8 @@ def label_smooth(label,
name=None): name=None):
""" """
Label smoothing is a mechanism to regularize the classifier layer and is Label smoothing is a mechanism to regularize the classifier layer and is
called label-smoothing regularization (LSR). called label-smoothing regularization (LSR).
Label smoothing is proposed to encourage the model to be less confident, Label smoothing is proposed to encourage the model to be less confident,
since optimizing the log-likelihood of the correct label directly may since optimizing the log-likelihood of the correct label directly may
cause overfitting and reduce the ability of the model to adapt. Label cause overfitting and reduce the ability of the model to adapt. Label
...@@ -3713,10 +3686,10 @@ def label_smooth(label, ...@@ -3713,10 +3686,10 @@ def label_smooth(label,
prior_dist(Variable): The prior distribution to be used to smooth prior_dist(Variable): The prior distribution to be used to smooth
labels. If not provided, an uniform distribution labels. If not provided, an uniform distribution
is used. The shape of :attr:`prior_dist` should is used. The shape of :attr:`prior_dist` should
be :math:`(1, class\_num)`. be :math:`(1, class\_num)`.
epsilon(float): The weight used to mix up the original ground-truth epsilon(float): The weight used to mix up the original ground-truth
distribution and the fixed distribution. distribution and the fixed distribution.
dtype(np.dtype|core.VarDesc.VarType|str): The type of data : float32, dtype(np.dtype|core.VarDesc.VarType|str): The type of data : float32,
float_64, int etc. float_64, int etc.
name(str|None): A name for this layer(optional). If set None, the layer name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically. will be named automatically.
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from test_mul_op import TestMulOp, TestMulOp2, TestFP16MulOp1, TestFP16MulOp2
class TestMKLDNNMulOp(TestMulOp):
def init_op_test(self):
super(TestMKLDNNMulOp, self).setUp()
self.attrs = {"use_mkldnn": True}
class TestMKLDNNMulOp2(TestMulOp2):
def init_op_test(self):
super(TestMKLDNNMulOp2, self).setUp()
self.attrs = {"use_mkldnn": True}
class TestMKLDNNFP16MulOp1(TestFP16MulOp1):
def init_op_test(self):
super(TestMKLDNNFP16MulOp1, self).setUp()
self.attrs = {"use_mkldnn": True}
class TestMKLDNNFP16MulOp2(TestFP16MulOp2):
def init_op_test(self):
super(TestMKLDNNFP16MulOp2, self).setUp()
self.attrs = {"use_mkldnn": True}
if __name__ == "__main__":
unittest.main()
...@@ -21,10 +21,12 @@ from op_test import OpTest ...@@ -21,10 +21,12 @@ from op_test import OpTest
class TestMulOp(OpTest): class TestMulOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "mul" self.op_type = "mul"
self.use_mkldnn = False
self.inputs = { self.inputs = {
'X': np.random.random((32, 84)).astype("float32"), 'X': np.random.random((32, 84)).astype("float32"),
'Y': np.random.random((84, 100)).astype("float32") 'Y': np.random.random((84, 100)).astype("float32")
} }
self.attrs = {'use_mkldnn': self.use_mkldnn}
self.outputs = {'Out': np.dot(self.inputs['X'], self.inputs['Y'])} self.outputs = {'Out': np.dot(self.inputs['X'], self.inputs['Y'])}
def test_check_output(self): def test_check_output(self):
...@@ -45,11 +47,16 @@ class TestMulOp(OpTest): ...@@ -45,11 +47,16 @@ class TestMulOp(OpTest):
class TestMulOp2(OpTest): class TestMulOp2(OpTest):
def setUp(self): def setUp(self):
self.op_type = "mul" self.op_type = "mul"
self.use_mkldnn = False
self.inputs = { self.inputs = {
'X': np.random.random((15, 4, 12, 10)).astype("float32"), 'X': np.random.random((15, 4, 12, 10)).astype("float32"),
'Y': np.random.random((4, 30, 8, 2, 9)).astype("float32") 'Y': np.random.random((4, 30, 8, 2, 9)).astype("float32")
} }
self.attrs = {'x_num_col_dims': 2, 'y_num_col_dims': 2} self.attrs = {
'x_num_col_dims': 2,
'y_num_col_dims': 2,
'use_mkldnn': self.use_mkldnn
}
result = np.dot(self.inputs['X'].reshape(15 * 4, 12 * 10), result = np.dot(self.inputs['X'].reshape(15 * 4, 12 * 10),
self.inputs['Y'].reshape(4 * 30, 8 * 2 * 9)) self.inputs['Y'].reshape(4 * 30, 8 * 2 * 9))
result = result.reshape(15, 4, 8, 2, 9) result = result.reshape(15, 4, 8, 2, 9)
...@@ -73,9 +80,11 @@ class TestMulOp2(OpTest): ...@@ -73,9 +80,11 @@ class TestMulOp2(OpTest):
class TestFP16MulOp1(OpTest): class TestFP16MulOp1(OpTest):
def setUp(self): def setUp(self):
self.op_type = "mul" self.op_type = "mul"
self.use_mkldnn = False
x = np.random.random((32, 84)).astype("float16") x = np.random.random((32, 84)).astype("float16")
y = np.random.random((84, 100)).astype("float16") y = np.random.random((84, 100)).astype("float16")
self.inputs = {'X': x.view(np.uint16), 'Y': y.view(np.uint16)} self.inputs = {'X': x.view(np.uint16), 'Y': y.view(np.uint16)}
self.attrs = {'use_mkldnn': self.use_mkldnn}
self.outputs = {'Out': np.dot(x, y)} self.outputs = {'Out': np.dot(x, y)}
def test_check_output(self): def test_check_output(self):
...@@ -88,12 +97,14 @@ class TestFP16MulOp1(OpTest): ...@@ -88,12 +97,14 @@ class TestFP16MulOp1(OpTest):
class TestFP16MulOp2(OpTest): class TestFP16MulOp2(OpTest):
def setUp(self): def setUp(self):
self.op_type = "mul" self.op_type = "mul"
self.use_mkldnn = False
x = np.random.random((15, 4, 12, 10)).astype("float16") x = np.random.random((15, 4, 12, 10)).astype("float16")
y = np.random.random((4, 30, 8, 2, 9)).astype("float16") y = np.random.random((4, 30, 8, 2, 9)).astype("float16")
self.inputs = {'X': x.view(np.uint16), 'Y': y.view(np.uint16)} self.inputs = {'X': x.view(np.uint16), 'Y': y.view(np.uint16)}
self.attrs = { self.attrs = {
'x_num_col_dims': 2, 'x_num_col_dims': 2,
'y_num_col_dims': 2, 'y_num_col_dims': 2,
'use_mkldnn': self.use_mkldnn
} }
result = np.dot( result = np.dot(
x.reshape(15 * 4, 12 * 10), y.reshape(4 * 30, 8 * 2 * 9)) x.reshape(15 * 4, 12 * 10), y.reshape(4 * 30, 8 * 2 * 9))
......
...@@ -62,7 +62,8 @@ class TestOperator(unittest.TestCase): ...@@ -62,7 +62,8 @@ class TestOperator(unittest.TestCase):
self.assertEqual(mul_op.output_names, ["Out"]) self.assertEqual(mul_op.output_names, ["Out"])
self.assertEqual(mul_op.output("Out"), ["mul.out"]) self.assertEqual(mul_op.output("Out"), ["mul.out"])
self.assertEqual( self.assertEqual(
set(mul_op.attr_names), set(["x_num_col_dims", "y_num_col_dims"])) set(mul_op.attr_names),
set(["x_num_col_dims", "y_num_col_dims", "use_mkldnn"]))
self.assertEqual(mul_op.has_attr("x_num_col_dims"), True) self.assertEqual(mul_op.has_attr("x_num_col_dims"), True)
self.assertEqual(mul_op.attr_type("x_num_col_dims"), core.AttrType.INT) self.assertEqual(mul_op.attr_type("x_num_col_dims"), core.AttrType.INT)
self.assertEqual(mul_op.attr("x_num_col_dims"), 1) self.assertEqual(mul_op.attr("x_num_col_dims"), 1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册