Commit 4466f0be authored by Krzysztof Binias

MKLDNN Relu Tanh Sqrt Abs activations added

Parent c83dd9b4
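For orientation: the new kernels wrap the MKL-DNN eltwise primitive. A minimal standalone sketch of that forward pattern (MKL-DNN 0.x API, the same one this commit uses; the function name and raw-pointer interface below are illustrative only, not part of the change) is:

// Illustrative sketch only: run one MKL-DNN eltwise (ReLU) forward pass on a
// raw NCHW float buffer, mirroring the pattern used by eltwise_forward() in
// the diff below.
#include <vector>
#include "mkldnn.hpp"

void relu_forward_sketch(const float *src, float *dst,
                         const std::vector<int> &nchw_dims) {
  auto cpu_engine = mkldnn::engine(mkldnn::engine::cpu, 0);
  // one memory descriptor is shared by source and destination
  auto data_md = mkldnn::memory::desc(nchw_dims, mkldnn::memory::f32,
                                      mkldnn::memory::format::nchw);
  auto src_memory = mkldnn::memory({data_md, cpu_engine}, (void *)src);
  auto dst_memory = mkldnn::memory({data_md, cpu_engine}, (void *)dst);
  // describe the eltwise op and create its primitive descriptor
  auto fwd_desc = mkldnn::eltwise_forward::desc(
      mkldnn::prop_kind::forward_training, mkldnn::algorithm::eltwise_relu,
      data_md, /*alpha=*/0.0f, /*beta=*/0.0f);
  auto fwd_pd = mkldnn::eltwise_forward::primitive_desc(fwd_desc, cpu_engine);
  auto relu = mkldnn::eltwise_forward(fwd_pd, src_memory, dst_memory);
  // push the primitive to a stream and wait until it finishes
  std::vector<mkldnn::primitive> pipeline = {relu};
  mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
}

The kernel added in activation_mkldnn_op.cc follows the same steps, but additionally stores the forward primitive descriptor in the device context (keyed by the Out variable name plus "@eltwise_pd") so the backward kernel can look it up.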
@@ -84,6 +84,10 @@ class OperatorBase {
return boost::get<T>(attrs_.at(name));
}
inline bool HasAttr(const std::string& name) const {
return attrs_.count(name) != 0;
}
/// if scope is not null, also show dimensions of arguments
virtual std::string DebugStringEx(const Scope* scope) const;
@@ -195,6 +199,10 @@ class ExecutionContext {
return op_.Attr<T>(name);
}
inline bool HasAttr(const std::string& name) const {
return op_.HasAttr(name);
}
size_t InputSize(const std::string& name) const {
return op_.Inputs(name).size();
}
......
@@ -153,7 +153,12 @@ function(op_library TARGET)
# pybind USE_OP_DEVICE_KERNEL for MKLDNN
if (WITH_MKLDNN AND ${mkldnn_cc_srcs_len} GREATER 0)
# Append first implemented MKLDNN activation operator
if (${MKLDNN_FILE} STREQUAL "activation_mkldnn_op")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(relu, MKLDNN);\n")
else()
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MKLDNN);\n")
endif()
endif()
# pybind USE_OP
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "mkldnn.hpp"
#include "paddle/fluid/operators/activation_op.h"
namespace paddle {
namespace operators {
using paddle::framework::Tensor;
using paddle::platform::MKLDNNDeviceContext;
namespace {
template <typename T, typename ExecContext>
void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
const T alpha = 0, const T beta = 0) {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto &mkldnn_engine = dev_ctx.GetEngine();
// get buffers
const auto *src = ctx.template Input<Tensor>("X");
const auto *src_data = src->template data<T>();
auto *dst = ctx.template Output<Tensor>("Out");
T *dst_data = dst->template mutable_data<T>(ctx.GetPlace());
// get memory dim
PADDLE_ENFORCE(src->dims().size() == 4,
"Input dim must be with 4, i.e. NCHW");
std::vector<int> src_tz = framework::vectorize2int(src->dims());
// create memory description
// TODO(kbinias-intel): support more formats
auto data_md = platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
mkldnn::memory::format::nchw);
// create memory primitives
auto src_memory = mkldnn::memory({data_md, mkldnn_engine}, (void *)src_data);
auto dst_memory = mkldnn::memory({data_md, mkldnn_engine}, (void *)dst_data);
auto forward_desc = mkldnn::eltwise_forward::desc(
mkldnn::prop_kind::forward_training, algorithm, data_md, alpha, beta);
// save prim desc into global device context to be referred in backward path
const std::string key = ctx.op().Output("Out");
const std::string key_eltwise_pd = key + "@eltwise_pd";
auto forward_pd = std::make_shared<mkldnn::eltwise_forward::primitive_desc>(
forward_desc, mkldnn_engine);
dev_ctx.SetBlob(key_eltwise_pd, forward_pd);
auto eltwise = mkldnn::eltwise_forward(*forward_pd, src_memory, dst_memory);
// push primitive to stream and wait until it's executed
std::vector<mkldnn::primitive> pipeline = {eltwise};
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
}
template <typename T, typename ExecContext>
void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
const T alpha = 0, const T beta = 0) {
auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto &mkldnn_engine = dev_ctx.GetEngine();
// get buffers
const auto *x = ctx.template Input<Tensor>("X");
const auto *src = x->template data<T>();
auto *dout = ctx.template Input<Tensor>(framework::GradVarName("Out"));
const auto *diff_dst = dout->template data<T>();
auto *dx =
ctx.template Output<framework::Tensor>(framework::GradVarName("X"));
T *diff_src = dx->template mutable_data<T>(ctx.GetPlace());
// get memory dim
std::vector<int> src_tz = framework::vectorize2int(x->dims());
// create memory description
auto data_md = platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
mkldnn::memory::format::nchw);
// create memory primitives
auto src_memory = mkldnn::memory({data_md, mkldnn_engine}, (void *)src);
auto diff_src_memory =
mkldnn::memory({data_md, mkldnn_engine}, (void *)diff_src);
auto diff_dst_memory =
mkldnn::memory({data_md, mkldnn_engine}, (void *)diff_dst);
auto backward_desc =
mkldnn::eltwise_backward::desc(algorithm, data_md, data_md, alpha, beta);
// retrieve eltwise primitive desc from device context
const std::string key = ctx.op().Input("Out");
const std::string key_eltwise_pd = key + "@eltwise_pd";
const std::shared_ptr<void> forward_pd = dev_ctx.GetBlob(key_eltwise_pd);
PADDLE_ENFORCE(forward_pd != nullptr,
"Fail to find eltwise_pd in device context");
auto *p_forward_pd =
static_cast<mkldnn::eltwise_forward::primitive_desc *>(forward_pd.get());
auto eltwise_bwd_prim_desc = mkldnn::eltwise_backward::primitive_desc(
backward_desc, mkldnn_engine, *p_forward_pd);
auto eltwise_bwd = mkldnn::eltwise_backward(eltwise_bwd_prim_desc, src_memory,
diff_dst_memory, diff_src_memory);
// push primitive to stream and wait until it's executed
std::vector<mkldnn::primitive> pipeline = {eltwise_bwd};
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
}
} // anonymous namespace
template <typename T, mkldnn::algorithm algorithm>
struct MKLDNNActivationFunc : public BaseActivationFunctor<T> {
template <typename ExecContext>
void operator()(const ExecContext &ctx) const {
eltwise_forward<T>(ctx, algorithm);
}
};
template <typename T, mkldnn::algorithm algorithm>
struct MKLDNNActivationGradFunc : public BaseActivationFunctor<T> {
template <typename ExecContext>
void operator()(const ExecContext &ctx) const {
eltwise_grad<T>(ctx, algorithm);
}
};
template <typename T>
using ReluMkldnnFunctor =
MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_relu>;
template <typename T>
using TanhMkldnnFunctor =
MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_tanh>;
template <typename T>
using SqrtMkldnnFunctor =
MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_sqrt>;
template <typename T>
using AbsMkldnnFunctor =
MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_abs>;
template <typename T>
using ReluMkldnnGradFunctor =
MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_relu>;
template <typename T>
using TanhMkldnnGradFunctor =
MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_tanh>;
template <typename T>
using SqrtMkldnnGradFunctor =
MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_sqrt>;
template <typename T>
using AbsMkldnnGradFunctor =
MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_abs>;
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
#define REGISTER_ACTIVATION_MKLDNN_KERNEL(act_type, functor, grad_functor) \
REGISTER_OP_KERNEL(act_type, MKLDNN, ::paddle::platform::CPUPlace, \
ops::MKLDNNActivationKernel<ops::functor<float>>); \
REGISTER_OP_KERNEL( \
act_type##_grad, MKLDNN, ::paddle::platform::CPUPlace, \
ops::MKLDNNActivationGradKernel<ops::grad_functor<float>>);
#define FOR_EACH_MKLDNN_KERNEL_FUNCTOR(__macro) \
__macro(relu, ReluMkldnnFunctor, ReluMkldnnGradFunctor) \
__macro(tanh, TanhMkldnnFunctor, TanhMkldnnGradFunctor) \
__macro(sqrt, SqrtMkldnnFunctor, SqrtMkldnnGradFunctor) \
__macro(abs, AbsMkldnnFunctor, AbsMkldnnGradFunctor);
FOR_EACH_MKLDNN_KERNEL_FUNCTOR(REGISTER_ACTIVATION_MKLDNN_KERNEL);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -25,6 +25,11 @@ class ActivationOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Out");
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return ActivationHelper().GetKernelType(ctx, *this);
}
};
class ActivationOpGrad : public framework::OperatorWithKernel {
@@ -34,6 +39,11 @@ class ActivationOpGrad : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext *ctx) const override {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out"));
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return ActivationHelper().GetKernelType(ctx, *this);
}
};
class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -87,6 +97,16 @@ class ReluOpMaker : public framework::OpProtoAndCheckerMaker {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Relu operator");
AddOutput("Out", "Output of Relu operator");
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<std::string>(
"data_format",
"(string, default \"AnyLayout\") Only used in MKLDNN kernel. "
"An optional string from: \"NHWC\", \"NCHW\". "
"Specify the data format of the output data; "
"the input will be transformed automatically.")
.SetDefault("AnyLayout");
AddComment(R"DOC( AddComment(R"DOC(
Relu Activation Operator. Relu Activation Operator.
...@@ -140,6 +160,16 @@ class TanhOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -140,6 +160,16 @@ class TanhOpMaker : public framework::OpProtoAndCheckerMaker {
: framework::OpProtoAndCheckerMaker(proto, op_checker) { : framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Tanh operator"); AddInput("X", "Input of Tanh operator");
AddOutput("Out", "Output of Tanh operator"); AddOutput("Out", "Output of Tanh operator");
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<std::string>(
"data_format",
"(string, default \"AnyLayout\") Only used in MKLDNN kernel. "
"An optional string from: \"NHWC\", \"NCHW\". "
"Specify the data format of the output data; "
"the input will be transformed automatically.")
.SetDefault("AnyLayout");
AddComment(R"DOC( AddComment(R"DOC(
Tanh Activation Operator. Tanh Activation Operator.
...@@ -193,6 +223,16 @@ class SqrtOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -193,6 +223,16 @@ class SqrtOpMaker : public framework::OpProtoAndCheckerMaker {
: framework::OpProtoAndCheckerMaker(proto, op_checker) { : framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Sqrt operator"); AddInput("X", "Input of Sqrt operator");
AddOutput("Out", "Output of Sqrt operator"); AddOutput("Out", "Output of Sqrt operator");
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<std::string>(
"data_format",
"(string, default \"AnyLayout\") Only used in MKLDNN kernel. "
"An optional string from: \"NHWC\", \"NCHW\". "
"Specify the data format of the output data; "
"the input will be transformed automatically.")
.SetDefault("AnyLayout");
AddComment(R"DOC( AddComment(R"DOC(
Sqrt Activation Operator. Sqrt Activation Operator.
...@@ -208,6 +248,16 @@ class AbsOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -208,6 +248,16 @@ class AbsOpMaker : public framework::OpProtoAndCheckerMaker {
: framework::OpProtoAndCheckerMaker(proto, op_checker) { : framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Abs operator"); AddInput("X", "Input of Abs operator");
AddOutput("Out", "Output of Abs operator"); AddOutput("Out", "Output of Abs operator");
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<std::string>(
"data_format",
"(string, default \"AnyLayout\") Only used in MKLDNN kernel. "
"An optional string from: \"NHWC\", \"NCHW\". "
"Specify the data format of the output data; "
"the input will be transformed automatically.")
.SetDefault("AnyLayout");
AddComment(R"DOC( AddComment(R"DOC(
Abs Activation Operator. Abs Activation Operator.
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -17,9 +17,36 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
namespace paddle {
namespace operators {
class ActivationHelper {
public:
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel& oper) const {
framework::LibraryType library{framework::LibraryType::kPlain};
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN;
}
#endif
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
if (ctx.HasAttr("data_format")) {
std::string data_format = ctx.Attr<std::string>("data_format");
layout = framework::StringToDataLayout(data_format);
}
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
ctx.GetPlace(), layout, library);
}
};
template <typename DeviceContext, typename Functor>
class ActivationKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
@@ -49,6 +76,27 @@ class ActivationKernel
}
};
template <typename Functor>
class MKLDNNActivationKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE(context.Input<framework::Tensor>("X") != nullptr,
"Cannot find input tensor X, variable name = %s",
context.op().Input("X"));
PADDLE_ENFORCE(context.Output<framework::Tensor>("Out") != nullptr,
"Cannot find output tensor Out, variable name = %s",
context.op().Output("Out"));
Functor functor;
auto attrs = functor.GetAttrs();
for (auto& attr : attrs) {
*attr.second = context.Attr<float>(attr.first);
}
functor(context);
}
};
template <typename DeviceContext, typename Functor>
class ActivationGradKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
@@ -77,6 +125,21 @@ class ActivationGradKernel
}
};
template <typename Functor>
class MKLDNNActivationGradKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
void Compute(const framework::ExecutionContext& context) const override {
Functor functor;
auto attrs = functor.GetAttrs();
for (auto& attr : attrs) {
*attr.second = context.Attr<float>(attr.first);
}
functor(context);
}
};
template <typename T>
struct BaseActivationFunctor {
using ELEMENT_TYPE = T;
......
@@ -42,6 +42,7 @@ inline mkldnn::memory::desc MKLDNNMemDesc(const std::vector<int>& dims,
}
inline bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx) {
if (!ctx.HasAttr("use_mkldnn")) return false;
bool use_mkldnn = ctx.Attr<bool>("use_mkldnn");
return use_mkldnn && platform::is_cpu_place(ctx.GetPlace());
}
......
@@ -403,6 +403,8 @@ class LayerHelper(object):
if 'use_mkldnn' in self.kwargs:
act['use_mkldnn'] = self.kwargs.get('use_mkldnn')
act_type = act.pop('type')
if 'use_mkldnn' in self.kwargs:
act['use_mkldnn'] = self.kwargs.get('use_mkldnn')
self.append_op(
type=act_type,
inputs={"X": [input_var]},
......
@@ -215,7 +215,8 @@ class OpTest(unittest.TestCase):
'''Fix random seeds to remove randomness from tests'''
cls._np_rand_state = np.random.get_state()
cls._py_rand_state = random.getstate()
cls.use_mkldnn = False
cls.data_format = 'AnyLayout'
np.random.seed(123)
random.seed(124)
@@ -340,7 +341,14 @@ class OpTest(unittest.TestCase):
"Output (" + out_name +
") has different lod at " + str(place))
def fill_attrs(self):
attrs = self.attrs if hasattr(self, "attrs") else dict()
attrs["use_mkldnn"] = self.use_mkldnn
attrs["data_format"] = self.data_format
return attrs
def check_output(self, atol=1e-5):
self.attrs = self.fill_attrs()
places = [core.CPUPlace()]
if core.is_compiled_with_cuda() and core.op_support_gpu(self.op_type):
places.append(core.CUDAPlace(0))
@@ -348,6 +356,7 @@ class OpTest(unittest.TestCase):
self.check_output_with_place(place, atol)
def check_output_customized(self, checker):
self.attrs = self.fill_attrs()
places = [core.CPUPlace()]
if core.is_compiled_with_cuda() and core.op_support_gpu(self.op_type):
places.append(core.CUDAPlace(0))
@@ -383,6 +392,7 @@ class OpTest(unittest.TestCase):
in_place=False,
max_relative_error=0.005,
user_defined_grads=None):
self.attrs = self.fill_attrs()
places = [core.CPUPlace()]
if core.is_compiled_with_cuda() and core.op_support_gpu(self.op_type):
places.append(core.CUDAPlace(0))
......
@@ -506,5 +506,72 @@ class TestSwish(OpTest):
self.check_grad(['X'], 'Out', max_relative_error=0.008)
#--------------------test MKLDNN--------------------
class TestMKLDNNRelu(OpTest):
def setUp(self):
self.op_type = "relu"
x = np.random.uniform(-1, 1, [2, 4, 3, 5]).astype("float32")
# The same reason with TestAbs
x[np.abs(x) < 0.005] = 0.02
self.inputs = {'X': x}
self.outputs = {'Out': np.maximum(self.inputs['X'], 0)}
self.use_mkldnn = True
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out', max_relative_error=0.007)
class TestMKLDNNTanh(OpTest):
def setUp(self):
self.op_type = "tanh"
self.inputs = {
'X': np.random.uniform(0.1, 1, [2, 4, 3, 5]).astype("float32")
}
self.outputs = {'Out': np.tanh(self.inputs['X'])}
self.use_mkldnn = True
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out', max_relative_error=0.007)
class TestMKLDNNSqrt(OpTest):
def setUp(self):
self.op_type = "sqrt"
self.inputs = {
'X': np.random.uniform(0.1, 1, [2, 4, 3, 5]).astype("float32")
}
self.outputs = {'Out': np.sqrt(self.inputs['X'])}
self.use_mkldnn = True
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out', max_relative_error=0.007)
class TestMKLDNNAbs(OpTest):
def setUp(self):
self.op_type = "abs"
x = np.random.uniform(-1, 1, [2, 4, 3, 5]).astype("float32")
# The same reason with TestAbs
x[np.abs(x) < 0.005] = 0.02
self.inputs = {'X': x}
self.outputs = {'Out': np.abs(self.inputs['X'])}
self.use_mkldnn = True
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out', max_relative_error=0.007)
if __name__ == "__main__":
unittest.main()