提交 1658958f 编写于 作者: K Krzysztof Binias

Reusing converted weights

上级 a5576083
...@@ -130,12 +130,13 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler { ...@@ -130,12 +130,13 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
std::shared_ptr<mkldnn::memory> AcquireWeightsMemoryFromPrimitive( std::shared_ptr<mkldnn::memory> AcquireWeightsMemoryFromPrimitive(
const std::shared_ptr<mkldnn::memory> user_weights_memory_p, const std::shared_ptr<mkldnn::memory> user_weights_memory_p,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT const std::vector<mkldnn::primitive>& pipeline,
bool is_test = false) { // NOLINT
auto user_weights_pd = user_weights_memory_p->get_primitive_desc(); auto user_weights_pd = user_weights_memory_p->get_primitive_desc();
auto weights_pd = conv_pd_->weights_primitive_desc(); auto weights_pd = conv_pd_->weights_primitive_desc();
return this->AcquireMemory(weights_pd, user_weights_pd, return this->AcquireMemory(weights_pd, user_weights_pd,
user_weights_memory_p, "@weights_mem_p", user_weights_memory_p, "@weights_mem_p",
pipeline); pipeline, is_test);
} }
std::shared_ptr<mkldnn::memory> AcquireBiasMemoryFromPrimitive( std::shared_ptr<mkldnn::memory> AcquireBiasMemoryFromPrimitive(
...@@ -266,6 +267,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -266,6 +267,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()), PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace."); "It must use CPUPlace.");
const bool is_test = ctx.Attr<bool>("is_test");
auto& dev_ctx = auto& dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>(); ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine(); const auto& mkldnn_engine = dev_ctx.GetEngine();
...@@ -371,7 +374,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -371,7 +374,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto src_memory_p = auto src_memory_p =
handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline); handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline);
auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive( auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive(
user_weights_memory_p, pipeline); user_weights_memory_p, pipeline, is_test);
auto dst_memory_p = auto dst_memory_p =
handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data)); handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
......
...@@ -109,6 +109,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( ...@@ -109,6 +109,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
} }
void Conv2DOpMaker::Make() { void Conv2DOpMaker::Make() {
AddAttr<bool>("is_test", "").SetDefault(false);
AddInput( AddInput(
"Input", "Input",
"(Tensor) The input tensor of convolution operator. " "(Tensor) The input tensor of convolution operator. "
......
...@@ -191,8 +191,8 @@ class MKLDNNHandler { ...@@ -191,8 +191,8 @@ class MKLDNNHandler {
mkldnn::memory::primitive_desc& mpd, // NOLINT mkldnn::memory::primitive_desc& mpd, // NOLINT
mkldnn::memory::primitive_desc& user_mpd, // NOLINT mkldnn::memory::primitive_desc& user_mpd, // NOLINT
const std::shared_ptr<mkldnn::memory> user_memory_p, const std::shared_ptr<mkldnn::memory> user_memory_p,
const std::string& suffix, const std::string& suffix, const std::vector<mkldnn::primitive>& pipeline,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT bool is_test = false) { // NOLINT
// create reorder primitive if the input format is not the preferred one // create reorder primitive if the input format is not the preferred one
auto local_key = key_ + suffix; auto local_key = key_ + suffix;
auto key_reorder_p = key_ + suffix + "reorder_p"; auto key_reorder_p = key_ + suffix + "reorder_p";
...@@ -213,7 +213,7 @@ class MKLDNNHandler { ...@@ -213,7 +213,7 @@ class MKLDNNHandler {
pipeline.push_back(*reorder_p); pipeline.push_back(*reorder_p);
} }
dev_ctx_.SetBlob(local_key, target_memory_p); dev_ctx_.SetBlob(local_key, target_memory_p);
} else { } else if (!is_test) {
// Make reorder if needed // Make reorder if needed
auto reorder_p = std::static_pointer_cast<mkldnn::reorder>( auto reorder_p = std::static_pointer_cast<mkldnn::reorder>(
dev_ctx_.GetBlob(key_reorder_p)); dev_ctx_.GetBlob(key_reorder_p));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册