提交 6724be2b 编写于 作者: X xiaolil1 提交者: Tao Luo

INT8 Pool kernel Key Creation Optimization. (#15883)

* Optimize key creation of INT8 pool kernel to improve the peformance of ResNet-50 and MobileNet, especially for latency.
test=develop

* Optimize key creation of pool fp32 grad.
test=develop
上级 d5a888e1
......@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/data_layout_transform.h"
#include "paddle/fluid/operators/pool_op.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
namespace paddle {
namespace operators {
......@@ -29,23 +30,23 @@ using mkldnn::stream;
using platform::to_void_cast;
// Generate keys for storing/retriving primitives for this operator
// TODO(jczaja): Make hashing function more optimial
static std::string gethash(const memory::dims& input_dims,
const std::string& pooling_type,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const memory::data_type& dt,
const std::string& suffix) {
auto dims2str = [](const memory::dims& operand_dims) {
std::string dstr = "";
for (size_t i = 0; i < operand_dims.size(); ++i) {
dstr += std::to_string(operand_dims[i]) + "-";
}
return dstr;
};
return dims2str(input_dims) + dims2str(ksize) + dims2str(strides) +
dims2str(paddings) + std::to_string(dt) + pooling_type + suffix;
std::string CreateKey(const paddle::framework::ExecutionContext& ctx,
const memory::dims& input_dims,
const std::string& pooling_type,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const memory::data_type& dt, const std::string& suffix) {
std::string key;
key.reserve(platform::MKLDNNHandler::MaxKeyLength);
platform::MKLDNNHandler::AppendKeyDims(&key, input_dims);
platform::MKLDNNHandler::AppendKey(&key, pooling_type);
platform::MKLDNNHandler::AppendKeyVec(&key, ksize);
platform::MKLDNNHandler::AppendKeyVec(&key, strides);
platform::MKLDNNHandler::AppendKeyVec(&key, paddings);
platform::MKLDNNHandler::AppendKey(&key, std::to_string(dt));
platform::MKLDNNHandler::AppendKey(&key, suffix);
return key;
}
static inline int ComputeCeiledOutput(int input_size, int kernel_size,
......@@ -114,8 +115,8 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
mkldnn::memory::data_type dt =
paddle::framework::ToMKLDNNDataType(input->type());
const std::string key = gethash(src_tz, pooling_type, ksize, strides,
paddings, dt, ctx.op().Output("Out"));
const std::string key = CreateKey(ctx, src_tz, pooling_type, ksize, strides,
paddings, dt, ctx.op().Output("Out"));
const std::string key_pool_p = key + "@pool_p";
const std::string key_pool_pd = key + "@pool_pd";
const std::string key_pool_src_mem_p = key + "@pool_src_mem_p";
......@@ -294,8 +295,8 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
// Get an unique name from "argument" name of "Out" variable
// This name will be used as key when referring info from device context
const std::string key =
gethash(diff_src_tz, pooling_type, ksize, strides, paddings,
memory::data_type::f32, ctx.op().Input("Out"));
CreateKey(ctx, diff_src_tz, pooling_type, ksize, strides, paddings,
memory::data_type::f32, ctx.op().Input("Out"));
const std::string key_pool_bwd_p = key + "@pool_bwd_p";
const std::string key_pool_diff_src_mem_p = key + "@pool_diff_src_mem_p";
const std::string key_pool_diff_dst_mem_p = key + "@pool_diff_dst_mem_p";
......
......@@ -271,7 +271,6 @@ class MKLDNNHandler {
AppendKey(key, suffix);
}
protected:
static void AppendKeyDims(std::string* key,
const mkldnn::memory::dims& dims) {
for (unsigned int i = 0; i < dims.size(); i++) {
......@@ -289,6 +288,7 @@ class MKLDNNHandler {
key->append(s);
}
protected:
static std::string dims2str(const mkldnn::memory::dims& operand_dims) {
std::string dstr = "";
for (size_t i = 0; i < operand_dims.size(); ++i) {
......@@ -302,6 +302,9 @@ class MKLDNNHandler {
mkldnn::engine engine_;
std::string key_;
bool is_reusing_;
public:
static constexpr int MaxKeyLength = 256;
};
class TransposeMKLDNNHandler : public MKLDNNHandler {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册