You need to sign in or sign up before continuing.
提交 a64b312e 编写于 作者: K Krzysztof Binias

Correcting for PR comments

上级 4466f0be
...@@ -84,10 +84,6 @@ class OperatorBase { ...@@ -84,10 +84,6 @@ class OperatorBase {
return boost::get<T>(attrs_.at(name)); return boost::get<T>(attrs_.at(name));
} }
inline bool HasAttr(const std::string& name) const {
return attrs_.count(name) != 0;
}
/// if scope is not null, also show dimensions of arguments /// if scope is not null, also show dimensions of arguments
virtual std::string DebugStringEx(const Scope* scope) const; virtual std::string DebugStringEx(const Scope* scope) const;
...@@ -199,10 +195,6 @@ class ExecutionContext { ...@@ -199,10 +195,6 @@ class ExecutionContext {
return op_.Attr<T>(name); return op_.Attr<T>(name);
} }
inline bool HasAttr(const std::string& name) const {
return op_.HasAttr(name);
}
size_t InputSize(const std::string& name) const { size_t InputSize(const std::string& name) const {
return op_.Inputs(name).size(); return op_.Inputs(name).size();
} }
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
limitations under the License. */ limitations under the License. */
#include "mkldnn.hpp" #include "mkldnn.hpp"
#include "mkldnn_activation_op.h"
#include "paddle/fluid/operators/activation_op.h" #include "paddle/fluid/operators/activation_op.h"
namespace paddle { namespace paddle {
...@@ -183,10 +184,10 @@ namespace ops = paddle::operators; ...@@ -183,10 +184,10 @@ namespace ops = paddle::operators;
act_type##_grad, MKLDNN, ::paddle::platform::CPUPlace, \ act_type##_grad, MKLDNN, ::paddle::platform::CPUPlace, \
ops::MKLDNNActivationGradKernel<ops::grad_functor<float>>); ops::MKLDNNActivationGradKernel<ops::grad_functor<float>>);
#define FOR_EACH_MKLDNN_KERNEL_FUNCTOR(__macro) \ #define FOR_EACH_MKLDNN_KERNEL_FUNCTOR(__macro) \
__macro(relu, ReluMkldnnFunctor, ReluMkldnnGradFunctor) \ __macro(relu, ReluMkldnnFunctor, ReluMkldnnGradFunctor); \
__macro(tanh, TanhMkldnnFunctor, TanhMkldnnGradFunctor) \ __macro(tanh, TanhMkldnnFunctor, TanhMkldnnGradFunctor); \
__macro(sqrt, SqrtMkldnnFunctor, SqrtMkldnnGradFunctor) \ __macro(sqrt, SqrtMkldnnFunctor, SqrtMkldnnGradFunctor); \
__macro(abs, AbsMkldnnFunctor, AbsMkldnnGradFunctor); __macro(abs, AbsMkldnnFunctor, AbsMkldnnGradFunctor);
FOR_EACH_MKLDNN_KERNEL_FUNCTOR(REGISTER_ACTIVATION_MKLDNN_KERNEL); FOR_EACH_MKLDNN_KERNEL_FUNCTOR(REGISTER_ACTIVATION_MKLDNN_KERNEL);
...@@ -100,13 +100,6 @@ class ReluOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -100,13 +100,6 @@ class ReluOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<bool>("use_mkldnn", AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel") "(bool, default false) Only used in mkldnn kernel")
.SetDefault(false); .SetDefault(false);
AddAttr<std::string>(
"data_format",
"(string, default NCHW) Only used in "
"An optional string from: \"NHWC\", \"NCHW\". "
"Defaults to \"NHWC\". Specify the data format of the output data, "
"the input will be transformed automatically. ")
.SetDefault("AnyLayout");
AddComment(R"DOC( AddComment(R"DOC(
Relu Activation Operator. Relu Activation Operator.
...@@ -163,13 +156,6 @@ class TanhOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -163,13 +156,6 @@ class TanhOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<bool>("use_mkldnn", AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel") "(bool, default false) Only used in mkldnn kernel")
.SetDefault(false); .SetDefault(false);
AddAttr<std::string>(
"data_format",
"(string, default NCHW) Only used in "
"An optional string from: \"NHWC\", \"NCHW\". "
"Defaults to \"NHWC\". Specify the data format of the output data, "
"the input will be transformed automatically. ")
.SetDefault("AnyLayout");
AddComment(R"DOC( AddComment(R"DOC(
Tanh Activation Operator. Tanh Activation Operator.
...@@ -226,13 +212,6 @@ class SqrtOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -226,13 +212,6 @@ class SqrtOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<bool>("use_mkldnn", AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel") "(bool, default false) Only used in mkldnn kernel")
.SetDefault(false); .SetDefault(false);
AddAttr<std::string>(
"data_format",
"(string, default NCHW) Only used in "
"An optional string from: \"NHWC\", \"NCHW\". "
"Defaults to \"NHWC\". Specify the data format of the output data, "
"the input will be transformed automatically. ")
.SetDefault("AnyLayout");
AddComment(R"DOC( AddComment(R"DOC(
Sqrt Activation Operator. Sqrt Activation Operator.
...@@ -251,13 +230,6 @@ class AbsOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -251,13 +230,6 @@ class AbsOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<bool>("use_mkldnn", AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel") "(bool, default false) Only used in mkldnn kernel")
.SetDefault(false); .SetDefault(false);
AddAttr<std::string>(
"data_format",
"(string, default NCHW) Only used in "
"An optional string from: \"NHWC\", \"NCHW\". "
"Defaults to \"NHWC\". Specify the data format of the output data, "
"the input will be transformed automatically. ")
.SetDefault("AnyLayout");
AddComment(R"DOC( AddComment(R"DOC(
Abs Activation Operator. Abs Activation Operator.
......
...@@ -37,10 +37,6 @@ class ActivationHelper { ...@@ -37,10 +37,6 @@ class ActivationHelper {
} }
#endif #endif
framework::DataLayout layout = framework::DataLayout::kAnyLayout; framework::DataLayout layout = framework::DataLayout::kAnyLayout;
if (ctx.HasAttr("data_format")) {
std::string data_format = ctx.Attr<std::string>("data_format");
layout = framework::StringToDataLayout(data_format);
}
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()), framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
ctx.GetPlace(), layout, library); ctx.GetPlace(), layout, library);
...@@ -76,27 +72,6 @@ class ActivationKernel ...@@ -76,27 +72,6 @@ class ActivationKernel
} }
}; };
template <typename Functor>
class MKLDNNActivationKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE(!context.HasAttr("X"),
"Cannot find input tensor X, variable name = %s",
context.op().Input("X"));
PADDLE_ENFORCE(!context.HasAttr("Out"),
"Cannot find output tensor Out, variable name = %s",
context.op().Output("Out"));
Functor functor;
auto attrs = functor.GetAttrs();
for (auto& attr : attrs) {
*attr.second = context.Attr<float>(attr.first);
}
functor(context);
}
};
template <typename DeviceContext, typename Functor> template <typename DeviceContext, typename Functor>
class ActivationGradKernel class ActivationGradKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> { : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
...@@ -125,21 +100,6 @@ class ActivationGradKernel ...@@ -125,21 +100,6 @@ class ActivationGradKernel
} }
}; };
template <typename Functor>
class MKLDNNActivationGradKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
void Compute(const framework::ExecutionContext& context) const override {
Functor functor;
auto attrs = functor.GetAttrs();
for (auto& attr : attrs) {
*attr.second = context.Attr<float>(attr.first);
}
functor(context);
}
};
template <typename T> template <typename T>
struct BaseActivationFunctor { struct BaseActivationFunctor {
using ELEMENT_TYPE = T; using ELEMENT_TYPE = T;
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
namespace paddle {
namespace operators {
template <typename Functor>
class MKLDNNActivationKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE(context.Input<framework::Tensor>("X") != nullptr,
"Cannot get input tensor X, variable name = %s",
context.op().Input("X"));
PADDLE_ENFORCE(context.Output<framework::Tensor>("Out") != nullptr,
"Cannot find output tensor Out, variable name = %s",
context.op().Output("Out"));
Functor functor;
auto attrs = functor.GetAttrs();
for (auto& attr : attrs) {
*attr.second = context.Attr<float>(attr.first);
}
functor(context);
}
};
template <typename Functor>
class MKLDNNActivationGradKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
void Compute(const framework::ExecutionContext& context) const override {
Functor functor;
auto attrs = functor.GetAttrs();
for (auto& attr : attrs) {
*attr.second = context.Attr<float>(attr.first);
}
functor(context);
}
};
} // namespace operators
} // namespace paddle
...@@ -42,7 +42,6 @@ inline mkldnn::memory::desc MKLDNNMemDesc(const std::vector<int>& dims, ...@@ -42,7 +42,6 @@ inline mkldnn::memory::desc MKLDNNMemDesc(const std::vector<int>& dims,
} }
inline bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx) { inline bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx) {
if (!ctx.HasAttr("use_mkldnn")) return false;
bool use_mkldnn = ctx.Attr<bool>("use_mkldnn"); bool use_mkldnn = ctx.Attr<bool>("use_mkldnn");
return use_mkldnn && platform::is_cpu_place(ctx.GetPlace()); return use_mkldnn && platform::is_cpu_place(ctx.GetPlace());
} }
......
...@@ -215,8 +215,7 @@ class OpTest(unittest.TestCase): ...@@ -215,8 +215,7 @@ class OpTest(unittest.TestCase):
'''Fix random seeds to remove randomness from tests''' '''Fix random seeds to remove randomness from tests'''
cls._np_rand_state = np.random.get_state() cls._np_rand_state = np.random.get_state()
cls._py_rand_state = random.getstate() cls._py_rand_state = random.getstate()
cls.use_mkldnn = False
cls.data_format = 'AnyLayout'
np.random.seed(123) np.random.seed(123)
random.seed(124) random.seed(124)
...@@ -341,14 +340,7 @@ class OpTest(unittest.TestCase): ...@@ -341,14 +340,7 @@ class OpTest(unittest.TestCase):
"Output (" + out_name + "Output (" + out_name +
") has different lod at " + str(place)) ") has different lod at " + str(place))
def fill_attrs(self):
attrs = self.attrs if hasattr(self, "attrs") else dict()
attrs["use_mkldnn"] = self.use_mkldnn
attrs["data_format"] = self.data_format
return attrs
def check_output(self, atol=1e-5): def check_output(self, atol=1e-5):
self.attrs = self.fill_attrs()
places = [core.CPUPlace()] places = [core.CPUPlace()]
if core.is_compiled_with_cuda() and core.op_support_gpu(self.op_type): if core.is_compiled_with_cuda() and core.op_support_gpu(self.op_type):
places.append(core.CUDAPlace(0)) places.append(core.CUDAPlace(0))
...@@ -356,7 +348,6 @@ class OpTest(unittest.TestCase): ...@@ -356,7 +348,6 @@ class OpTest(unittest.TestCase):
self.check_output_with_place(place, atol) self.check_output_with_place(place, atol)
def check_output_customized(self, checker): def check_output_customized(self, checker):
self.attrs = self.fill_attrs()
places = [core.CPUPlace()] places = [core.CPUPlace()]
if core.is_compiled_with_cuda() and core.op_support_gpu(self.op_type): if core.is_compiled_with_cuda() and core.op_support_gpu(self.op_type):
places.append(core.CUDAPlace(0)) places.append(core.CUDAPlace(0))
...@@ -392,7 +383,6 @@ class OpTest(unittest.TestCase): ...@@ -392,7 +383,6 @@ class OpTest(unittest.TestCase):
in_place=False, in_place=False,
max_relative_error=0.005, max_relative_error=0.005,
user_defined_grads=None): user_defined_grads=None):
self.attrs = self.fill_attrs()
places = [core.CPUPlace()] places = [core.CPUPlace()]
if core.is_compiled_with_cuda() and core.op_support_gpu(self.op_type): if core.is_compiled_with_cuda() and core.op_support_gpu(self.op_type):
places.append(core.CUDAPlace(0)) places.append(core.CUDAPlace(0))
......
...@@ -515,7 +515,7 @@ class TestMKLDNNRelu(OpTest): ...@@ -515,7 +515,7 @@ class TestMKLDNNRelu(OpTest):
x[np.abs(x) < 0.005] = 0.02 x[np.abs(x) < 0.005] = 0.02
self.inputs = {'X': x} self.inputs = {'X': x}
self.outputs = {'Out': np.maximum(self.inputs['X'], 0)} self.outputs = {'Out': np.maximum(self.inputs['X'], 0)}
self.use_mkldnn = True self.attrs = {"use_mkldnn": True}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
...@@ -531,7 +531,7 @@ class TestMKLDNNTanh(OpTest): ...@@ -531,7 +531,7 @@ class TestMKLDNNTanh(OpTest):
'X': np.random.uniform(0.1, 1, [2, 4, 3, 5]).astype("float32") 'X': np.random.uniform(0.1, 1, [2, 4, 3, 5]).astype("float32")
} }
self.outputs = {'Out': np.tanh(self.inputs['X'])} self.outputs = {'Out': np.tanh(self.inputs['X'])}
self.use_mkldnn = True self.attrs = {"use_mkldnn": True}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
...@@ -547,7 +547,7 @@ class TestMKLDNNSqrt(OpTest): ...@@ -547,7 +547,7 @@ class TestMKLDNNSqrt(OpTest):
'X': np.random.uniform(0.1, 1, [2, 4, 3, 5]).astype("float32") 'X': np.random.uniform(0.1, 1, [2, 4, 3, 5]).astype("float32")
} }
self.outputs = {'Out': np.sqrt(self.inputs['X'])} self.outputs = {'Out': np.sqrt(self.inputs['X'])}
self.use_mkldnn = True self.attrs = {"use_mkldnn": True}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
...@@ -564,7 +564,7 @@ class TestMKLDNNAbs(OpTest): ...@@ -564,7 +564,7 @@ class TestMKLDNNAbs(OpTest):
x[np.abs(x) < 0.005] = 0.02 x[np.abs(x) < 0.005] = 0.02
self.inputs = {'X': x} self.inputs = {'X': x}
self.outputs = {'Out': np.abs(self.inputs['X'])} self.outputs = {'Out': np.abs(self.inputs['X'])}
self.use_mkldnn = True self.attrs = {"use_mkldnn": True}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册