提交 dee7f813 编写于 作者: T tensor-tang

add finish work of mkldnn

上级 44ed21ee
...@@ -21,6 +21,10 @@ limitations under the License. */ ...@@ -21,6 +21,10 @@ limitations under the License. */
#include "paddle/utils/Logging.h" #include "paddle/utils/Logging.h"
#include "paddle/utils/Stat.h" #include "paddle/utils/Stat.h"
#ifdef PADDLE_USE_MKLDNN
#include "paddle/gserver/layers/MKLDNNLayer.h"
#endif
#ifndef PADDLE_MOBILE_INFERENCE #ifndef PADDLE_MOBILE_INFERENCE
#include "MultiNetwork.h" #include "MultiNetwork.h"
#include "RecurrentGradientMachine.h" #include "RecurrentGradientMachine.h"
...@@ -300,6 +304,17 @@ void NeuralNetwork::backward(const UpdateCallback& callback) { ...@@ -300,6 +304,17 @@ void NeuralNetwork::backward(const UpdateCallback& callback) {
} }
} }
void NeuralNetwork::finish() {
#ifdef PADDLE_USE_MKLDNN
FOR_EACH_R(layer, layers_) {
MKLDNNLayerPtr dnnLayer = std::dynamic_pointer_cast<MKLDNNLayer>(*layer);
if (dnnLayer) {
dnnLayer->convertWeightsToPaddle();
}
}
#endif
}
Argument NeuralNetwork::getLayerOutput(const std::string& layerName) { Argument NeuralNetwork::getLayerOutput(const std::string& layerName) {
return getLayer(layerName)->getOutput(); return getLayer(layerName)->getOutput();
} }
......
...@@ -134,6 +134,9 @@ public: ...@@ -134,6 +134,9 @@ public:
const std::string& getName() const { return subModelName_; } const std::string& getName() const { return subModelName_; }
/// some finish work, like convert the weight format of MKLDNNLayers
void finish() override;
protected: protected:
/** /**
* The constructor of NeuralNetwork. * The constructor of NeuralNetwork.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册