未验证 提交 7260e3a4 编写于 作者: T tensor-tang 提交者: GitHub

Merge pull request #9214 from jczaja/prv-softmax-mkldnn-operator-PR

Softmax MKLDNN FLUID operator
......@@ -78,7 +78,7 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel<T> {
for (int64_t i = 0; i < batch_size; ++i) {
PADDLE_ASSERT(label_data[i] >= 0 || label_data[i] < class_num);
int64_t index = i * class_num + label_data[i];
dx_data[index] = -dy_data[i] / x_data[index];
dx_data[index] = math::TolerableValue<T>()(-dy_data[i] / x_data[index]);
}
}
}
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "mkldnn.hpp"
#include "paddle/fluid/operators/softmax_op.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include <iostream>
namespace paddle {
namespace operators {
using paddle::framework::Tensor;
using paddle::platform::MKLDNNDeviceContext;
using paddle::platform::MKLDNNMemDesc;
using mkldnn::memory; // Note: paddle has also "memory" namespace
using mkldnn::primitive;
using mkldnn::softmax_forward;
using mkldnn::prop_kind;
using mkldnn::stream;
template <typename T>
class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
auto mkldnn_engine = dev_ctx.GetEngine();
const Tensor* input = ctx.Input<Tensor>("X");
Tensor* output = ctx.Output<Tensor>("Out");
PADDLE_ENFORCE(input->dims().size() == 2UL,
"The input of softmax op must be a 2D matrix.");
const T* input_data = input->data<T>();
// allocate memory for output
T* output_data = output->mutable_data<T>(ctx.GetPlace());
std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
// MKL-DNN does support softmax over selected axis. Having 2D Tensor,
// we will make normalization after final eg. axis: 1
PADDLE_ENFORCE(((src_tz[0] == dst_tz[0]) && (src_tz[1] == dst_tz[1])),
"Softmax input and output dimensions should match");
// Same memory descriptor to be used for input and output
memory::dims softmax_tz = {src_tz[0], src_tz[1]};
// Currently only supports NC data format
// TODO(jczaja-intel): support more formats
auto softmax_md =
MKLDNNMemDesc({softmax_tz}, memory::f32, memory::format::nc);
// Normalization is made after innermost dimension eg. C out of NC
auto softmax_desc = softmax_forward::desc(prop_kind::forward_scoring,
softmax_md, 1 /*dim: C*/);
// create memory primitives
auto softmax_src_memory =
memory({softmax_md, mkldnn_engine}, (void*)input_data);
auto softmax_dst_memory =
memory({softmax_md, mkldnn_engine}, (void*)output_data);
auto softmax_prim_desc =
softmax_forward::primitive_desc(softmax_desc, mkldnn_engine);
auto softmax = softmax_forward(softmax_prim_desc, softmax_src_memory,
softmax_dst_memory);
std::vector<primitive> pipeline{softmax};
stream(stream::kind::eager).submit(pipeline).wait();
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(softmax, MKLDNN, ::paddle::platform::CPUPlace,
ops::SoftmaxMKLDNNKernel<float>);
......@@ -14,6 +14,9 @@ limitations under the License. */
#include "paddle/fluid/operators/softmax_op.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
namespace paddle {
namespace operators {
......@@ -51,13 +54,18 @@ class SoftmaxOp : public framework::OperatorWithKernel {
if (use_cudnn && runtime_cudnn_support) {
library_ = framework::LibraryType::kCUDNN;
}
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
}
#endif
std::string data_format = ctx.Attr<std::string>("data_format");
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
framework::StringToDataLayout(data_format), library_);
}
};
class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SoftmaxOpMaker(OpProto* proto, OpAttrChecker* op_checker)
......@@ -77,6 +85,9 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
"Defaults to \"NHWC\". Specify the data format of the output data, "
"the input will be transformed automatically. ")
.SetDefault("AnyLayout");
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddComment(R"DOC(
Softmax Operator.
......
......@@ -399,6 +399,9 @@ class LayerHelper(object):
if isinstance(act, basestring):
act = {'type': act}
tmp = self.create_tmp_variable(dtype=input_var.dtype)
if 'use_mkldnn' in self.kwargs:
act['use_mkldnn'] = self.kwargs.get('use_mkldnn')
act_type = act.pop('type')
self.append_op(
type=act_type,
......
......@@ -82,6 +82,7 @@ def fc(input,
num_flatten_dims=1,
param_attr=None,
bias_attr=None,
use_mkldnn=False,
act=None,
name=None):
"""
......@@ -163,8 +164,11 @@ def fc(input,
inputs={"X": input_var,
"Y": w},
outputs={"Out": tmp},
attrs={"x_num_col_dims": num_flatten_dims,
"y_num_col_dims": 1})
attrs={
"x_num_col_dims": num_flatten_dims,
"y_num_col_dims": 1,
'use_mkldnn': use_mkldnn
})
mul_results.append(tmp)
# sum
......
......@@ -27,15 +27,20 @@ def stable_softmax(x):
class TestSoftmaxOp(OpTest):
def setUp(self):
self.use_mkldnn = False
self.op_type = "softmax"
self.use_cudnn = False
self.init_op_type()
self.inputs = {
'X': np.random.uniform(0.1, 1, [10, 10]).astype("float32")
}
self.outputs = {
'Out': np.apply_along_axis(stable_softmax, 1, self.inputs['X'])
}
self.attrs = {'use_cudnn': self.use_cudnn, }
self.attrs = {
'use_cudnn': self.use_cudnn,
'use_mkldnn': self.use_mkldnn
}
def init_op_type(self):
pass
......@@ -61,5 +66,10 @@ class TestSoftmaxCUDNNOp(TestSoftmaxOp):
self.use_cudnn = True
class TestMKLDNN(TestSoftmaxOp):
def init_op_type(self):
self.use_mkldnn = True
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册