提交 1658958f 编写于 作者: K Krzysztof Binias

Reusing converted weights

上级 a5576083
......@@ -130,12 +130,13 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
std::shared_ptr<mkldnn::memory> AcquireWeightsMemoryFromPrimitive(
const std::shared_ptr<mkldnn::memory> user_weights_memory_p,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT
const std::vector<mkldnn::primitive>& pipeline,
bool is_test = false) { // NOLINT
auto user_weights_pd = user_weights_memory_p->get_primitive_desc();
auto weights_pd = conv_pd_->weights_primitive_desc();
return this->AcquireMemory(weights_pd, user_weights_pd,
user_weights_memory_p, "@weights_mem_p",
pipeline);
pipeline, is_test);
}
std::shared_ptr<mkldnn::memory> AcquireBiasMemoryFromPrimitive(
......@@ -266,6 +267,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
const bool is_test = ctx.Attr<bool>("is_test");
auto& dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
......@@ -371,7 +374,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto src_memory_p =
handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline);
auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive(
user_weights_memory_p, pipeline);
user_weights_memory_p, pipeline, is_test);
auto dst_memory_p =
handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
......
......@@ -109,6 +109,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
}
void Conv2DOpMaker::Make() {
AddAttr<bool>("is_test", "").SetDefault(false);
AddInput(
"Input",
"(Tensor) The input tensor of convolution operator. "
......
......@@ -191,8 +191,8 @@ class MKLDNNHandler {
mkldnn::memory::primitive_desc& mpd, // NOLINT
mkldnn::memory::primitive_desc& user_mpd, // NOLINT
const std::shared_ptr<mkldnn::memory> user_memory_p,
const std::string& suffix,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT
const std::string& suffix, const std::vector<mkldnn::primitive>& pipeline,
bool is_test = false) { // NOLINT
// create reorder primitive if the input format is not the preferred one
auto local_key = key_ + suffix;
auto key_reorder_p = key_ + suffix + "reorder_p";
......@@ -213,7 +213,7 @@ class MKLDNNHandler {
pipeline.push_back(*reorder_p);
}
dev_ctx_.SetBlob(local_key, target_memory_p);
} else {
} else if (!is_test) {
// Make reorder if needed
auto reorder_p = std::static_pointer_cast<mkldnn::reorder>(
dev_ctx_.GetBlob(key_reorder_p));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册