Commit 738069e4 authored by Michal Gallus

Refactor MKL-DNN Concat

test=develop
Parent 208f9125
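In summary, this commit deletes the pooling-specific gethash helper and the large block of commented-out pooling code that had been carried over into the concat kernel, and moves all MKL-DNN primitive construction out of ConcatMKLDNNOpKernel::Compute into a new ConcatPrimitiveFactory. The kernel body reduces to the call sequence below (condensed from the new code in the diff; not a standalone program):

    ConcatPrimitiveFactory<T> prim_creator;
    auto concat_pd = prim_creator.CreateConcatPrimDescriptor(
        multi_input, output, static_cast<int>(concat_axis), mkldnn_engine);
    auto concat = prim_creator.CreateConcatPrimitive(concat_pd, output, place);
    stream(stream::kind::eager).submit({concat}).wait();

    output->set_layout(DataLayout::kMKLDNN);
    output->set_format(GetDstMemFormat(concat_pd));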
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
+#include <memory>
 #include "paddle/fluid/operators/concat_op.h"
 #include "paddle/fluid/platform/mkldnn_helper.h"
@@ -26,25 +27,6 @@ using mkldnn::concat;
 using mkldnn::stream;
 using platform::to_void_cast;
-
-// Generate keys for storing/retriving primitives for this operator
-// TODO(jczaja): Make hashing function more optimial
-static std::string gethash(const memory::dims& input_dims,
-                           const std::string& pooling_type,
-                           const std::vector<int>& ksize,
-                           const std::vector<int>& strides,
-                           const std::vector<int>& paddings,
-                           const std::string& suffix) {
-  auto dims2str = [](const memory::dims& operand_dims) {
-    std::string dstr = "";
-    for (size_t i = 0; i < operand_dims.size(); ++i) {
-      dstr += std::to_string(operand_dims[i]) + "-";
-    }
-    return dstr;
-  };
-  return dims2str(input_dims) + dims2str(ksize) + dims2str(strides) +
-         dims2str(paddings) + pooling_type + suffix;
-}
 static void EnforceLayouts(const std::vector<const Tensor*> inputs) {
   for (auto* input : inputs) {
     const bool is_layout_correct = input->layout() == DataLayout::kMKLDNN;
@@ -56,7 +38,7 @@ static void EnforceLayouts(const std::vector<const Tensor*> inputs) {
 }
 
 static memory::primitive_desc CreateMemPrimDesc(
-    const framework::Tensor& input, const mkldnn::engine& engine) {
+    const Tensor& input, const mkldnn::engine& engine) {
   constexpr auto data_type = mkldnn::memory::f32;
   const auto dims = paddle::framework::vectorize2int(input.dims());
   const auto format = input.format();
@@ -65,6 +47,11 @@ static memory::primitive_desc CreateMemPrimDesc(
   return mem_prim_desc;
 }
 
+static mkldnn::memory::format GetDstMemFormat(
+    const concat::primitive_desc& concat_pd) {
+  return (memory::format)concat_pd.dst_primitive_desc().desc().data.format;
+}
+
 static platform::CPUPlace GetCpuPlace(
     const paddle::framework::ExecutionContext& ctx) {
   auto place = ctx.GetPlace();
@@ -73,139 +60,87 @@ static platform::CPUPlace GetCpuPlace(
   return boost::get<platform::CPUPlace>(place);
 }
 
-template <typename T>
-class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
- public:
-  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
-    auto place = GetCpuPlace(ctx);
-    auto& dev_ctx =
-        ctx.template device_context<platform::MKLDNNDeviceContext>();
-    const auto& mkldnn_engine = dev_ctx.GetEngine();
-
-    auto multi_input = ctx.MultiInput<framework::Tensor>("X");
-    framework::Tensor* output = ctx.Output<framework::Tensor>("Out");
-    int64_t concat_axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
-
-    EnforceLayouts(multi_input);
-
-    std::vector<memory::primitive_desc> srcs_pd;
-    std::vector<memory> srcs;
+static const mkldnn::engine& GetMKLDNNEngine(
+    const paddle::framework::ExecutionContext& ctx) {
+  auto& dev_ctx =
+      ctx.template device_context<platform::MKLDNNDeviceContext>();
+  return dev_ctx.GetEngine();
+}
+
+template <typename T>
+class ConcatPrimitiveFactory {
+ public:
+  concat::primitive_desc CreateConcatPrimDescriptor(
+      const std::vector<const Tensor*> multi_input, Tensor* output,
+      int concat_axis, const mkldnn::engine& mkldnn_engine) {
+    CreateSourcesDescriptors(multi_input, mkldnn_engine);
+    auto dst_desc = CreateDstMemDescriptor(output);
+    return concat::primitive_desc(dst_desc, concat_axis, srcs_pd);
+  }
+
+  concat CreateConcatPrimitive(const concat::primitive_desc& concat_pd,
+                               Tensor* output, platform::CPUPlace place) {
+    CreateSourcePrimitiveAts();
+    auto dst_mem = CreateDstMemory(concat_pd, output, place);
+    return concat(concat_pd, inputs, dst_mem);
+  }
+
+ private:
+  memory::desc CreateDstMemDescriptor(Tensor* output) {
+    auto dst_dims = paddle::framework::vectorize2int(output->dims());
+    return memory::desc(dst_dims, platform::MKLDNNGetDataType<T>(),
+                        memory::format::any);
+  }
+
+  mkldnn::memory CreateDstMemory(const concat::primitive_desc& concat_pd,
+                                 Tensor* output, platform::CPUPlace place) {
+    return memory(concat_pd.dst_primitive_desc(),
+                  output->mutable_data<T>(place));
+  }
+
+  void CreateSourcesDescriptors(const std::vector<const Tensor*> multi_input,
+                                const mkldnn::engine& mkldnn_engine) {
     for (size_t i = 0; i < multi_input.size(); i++) {
       auto mem_prim_desc = CreateMemPrimDesc(*multi_input[i], mkldnn_engine);
       srcs_pd.push_back(mem_prim_desc);
-      srcs.push_back(memory(mem_prim_desc, to_void_cast(multi_input[i]->data<T>())));
+      srcs.push_back(memory(mem_prim_desc,
+                            to_void_cast(multi_input[i]->data<T>())));
     }
-    auto dst_dims = paddle::framework::vectorize2int(output->dims());
-    auto dst_desc = memory::desc(dst_dims, mkldnn::memory::f32, memory::format::any);
-    auto concat_pd = concat::primitive_desc(dst_desc, static_cast<int>(concat_axis), srcs_pd);
-    auto dst_mem = memory(concat_pd.dst_primitive_desc(), output->mutable_data<T>(place));
-
-    std::vector<primitive::at> inputs; //= {srcs};
+  }
+
+  void CreateSourcePrimitiveAts() {
     inputs.reserve(srcs.size());
     for (size_t i = 0; i < srcs.size(); i++) {
       inputs.push_back(srcs[i]);
     }
-    auto concat_prim = concat(concat_pd, inputs, dst_mem);
-
-    std::vector<primitive> pipeline;
-    pipeline.push_back(concat_prim);
-    stream(stream::kind::eager).submit(pipeline).wait();
-
-    /*
-    const T* input_data = input->data<T>();
-    T* output_data = output->mutable_data<T>(ctx.GetPlace());
-
-    std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
-    std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
-
-    auto input_format = input->format();
-    memory::format output_format{memory::format::format_undef};
-
-    const std::string key = gethash(src_tz, pooling_type, ksize, strides,
-                                    paddings, ctx.op().Output("Out"));
-    const std::string key_pool_p = key + "@pool_p";
-    const std::string key_pool_pd = key + "@pool_pd";
-    const std::string key_pool_src_mem_p = key + "@pool_src_mem_p";
-    const std::string key_pool_dst_mem_p = key + "@pool_dst_mem_p";
-    const std::string key_pool_workspace_memory =
-        key + "@pool_workspace_memory";
-
-    auto pool_p =
-        std::static_pointer_cast<pooling_forward>(dev_ctx.GetBlob(key_pool_p));
-    if (pool_p == nullptr) {
-      const std::vector<int>& padding_left_top(paddings);
-      std::vector<int> padding_right_bottom(paddings);
-      bool ceil_mode = ctx.Attr<bool>("ceil_mode");
-      if (ceil_mode) {
-        CorrectOutputSize(src_tz, dst_tz, ksize, paddings, strides,
-                          padding_right_bottom);
-      }
-      auto src_md = platform::MKLDNNMemDesc(
-          src_tz, platform::MKLDNNGetDataType<T>(), input_format);
-      auto dst_md = platform::MKLDNNMemDesc(dst_tz, mkldnn::memory::f32,
-                                            mkldnn::memory::format::any);
-      std::shared_ptr<mkldnn::pooling_forward::primitive_desc> pool_pd =
-          CreatePrimitiveDesc(src_md, dst_md, strides, padding_left_top,
-                              padding_right_bottom, ksize, pooling_type,
-                              mkldnn_engine, ceil_mode, is_test);
-
-      // save pool_pd into global device context to be referred in backward path
-      if (!is_test) dev_ctx.SetBlob(key_pool_pd, pool_pd);
-
-      auto src_memory = std::make_shared<memory>(pool_pd->src_primitive_desc(),
-                                                 to_void_cast<T>(input_data));
-      auto dst_memory =
-          std::make_shared<memory>(pool_pd->dst_primitive_desc(), output_data);
-
-      dev_ctx.SetBlob(key_pool_src_mem_p, src_memory);
-      dev_ctx.SetBlob(key_pool_dst_mem_p, dst_memory);
-
-      if (is_test) {
-        pool_p = std::make_shared<pooling_forward>(*pool_pd, *src_memory,
-                                                   *dst_memory);
-      } else {
-        std::shared_ptr<mkldnn::memory> workspace_memory =
-            CreateWorkspaceMemory(pool_pd, pooling_type, mkldnn_engine);
-
-        // save pool_workspace_memory to be referred in backward path
-        dev_ctx.SetBlob(key_pool_workspace_memory, workspace_memory);
-
-        pool_p = std::make_shared<pooling_forward>(
-            *pool_pd, *src_memory, *dst_memory, *workspace_memory);
-      }
-
-      dev_ctx.SetBlob(key_pool_p, pool_p);
-
-      output_format =
-          (memory::format)dst_memory->get_primitive_desc().desc().data.format;
-    } else {
-      // Primitives already exist
-      auto pool_src_memory_p =
-          std::static_pointer_cast<memory>(dev_ctx.GetBlob(key_pool_src_mem_p));
-      PADDLE_ENFORCE(pool_src_memory_p != nullptr,
-                     "Fail to find pooling src mem_p in device context");
-      auto pool_dst_memory_p =
-          std::static_pointer_cast<memory>(dev_ctx.GetBlob(key_pool_dst_mem_p));
-      PADDLE_ENFORCE(pool_dst_memory_p != nullptr,
-                     "Fail to find pooling dst mem_p in device context");
-      pool_src_memory_p->set_data_handle(to_void_cast<T>(input_data));
-      pool_dst_memory_p->set_data_handle(output_data);
-      output_format = (memory::format)pool_dst_memory_p->get_primitive_desc()
-                          .desc()
-                          .data.format;
-    }
-
-    // push primitive to stream and wait until it's executed
-    std::vector<mkldnn::primitive> pipeline{*(pool_p.get())};
-    stream(stream::kind::eager).submit(pipeline).wait();
-    */
-
-    output->mutable_data(place);
+    // TODO(mgallus): When this is not workin' split into decl and def
+  }
+
+ private:
+  std::vector<memory::primitive_desc> srcs_pd;
+  std::vector<memory> srcs;
+  std::vector<primitive::at> inputs;
+};
+
+template <typename T>
+class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
+ public:
+  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
+    auto place = GetCpuPlace(ctx);
+    const auto& mkldnn_engine = GetMKLDNNEngine(ctx);
+
+    auto multi_input = ctx.MultiInput<Tensor>("X");
+    EnforceLayouts(multi_input);
+    Tensor* output = ctx.Output<Tensor>("Out");
+    int64_t concat_axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
+
+    ConcatPrimitiveFactory<T> prim_creator;
+    auto concat_pd = prim_creator.CreateConcatPrimDescriptor(
+        multi_input, output, static_cast<int>(concat_axis), mkldnn_engine);
+    auto concat = prim_creator.CreateConcatPrimitive(concat_pd, output, place);
+    stream(stream::kind::eager).submit({concat}).wait();
+
     output->set_layout(DataLayout::kMKLDNN);
-    output->set_format((memory::format)dst_mem.get_primitive_desc().desc()
-                           .data.format);
+    output->set_format(GetDstMemFormat(concat_pd));
   }
 };
 }  // namespace operators
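For readers unfamiliar with the v0.x MKL-DNN C++ API that the factory wraps, the sketch below shows the same concat call sequence as a self-contained program. It is illustrative only: the shapes, values, and variable names (src_pd, dst_md, and so on) are assumptions for this sketch, not part of the commit.

    // Standalone sketch of the v0.x MKL-DNN concat sequence used above.
    // All shapes, data, and names here are illustrative assumptions.
    #include <vector>
    #include "mkldnn.hpp"

    int main() {
      using namespace mkldnn;
      engine eng(engine::cpu, 0);

      // Two 2x3 f32 tensors, concatenated along axis 1 into one 2x6 tensor.
      std::vector<float> a(6, 1.f), b(6, 2.f), out(12, 0.f);
      auto src_md = memory::desc({2, 3}, memory::data_type::f32,
                                 memory::format::nc);
      auto src_pd = memory::primitive_desc(src_md, eng);
      memory src0(src_pd, a.data());
      memory src1(src_pd, b.data());

      // As in CreateConcatPrimDescriptor: describe the destination with
      // format::any and let concat::primitive_desc choose the layout.
      std::vector<memory::primitive_desc> srcs_pd{src_pd, src_pd};
      auto dst_md = memory::desc({2, 6}, memory::data_type::f32,
                                 memory::format::any);
      auto concat_pd = concat::primitive_desc(dst_md, /*concat_axis=*/1, srcs_pd);
      memory dst(concat_pd.dst_primitive_desc(), out.data());

      // As in CreateConcatPrimitive + Compute: build the primitive and
      // submit it to an eager stream.
      std::vector<primitive::at> inputs{src0, src1};
      auto c = concat(concat_pd, inputs, dst);
      stream(stream::kind::eager).submit({c}).wait();
      return 0;
    }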