提交 93608359 编写于 作者: H hedaoyuan

Fix UND AgentLayer.

上级 cad7bd13
...@@ -16,7 +16,6 @@ limitations under the License. */ ...@@ -16,7 +16,6 @@ limitations under the License. */
#include "NeuralNetwork.h" #include "NeuralNetwork.h"
#include "hl_gpu.h" #include "hl_gpu.h"
#include "paddle/gserver/layers/AgentLayer.h"
#include "paddle/utils/CustomStackTrace.h" #include "paddle/utils/CustomStackTrace.h"
#include "paddle/utils/Logging.h" #include "paddle/utils/Logging.h"
#include "paddle/utils/Stat.h" #include "paddle/utils/Stat.h"
...@@ -28,6 +27,7 @@ limitations under the License. */ ...@@ -28,6 +27,7 @@ limitations under the License. */
#ifndef PADDLE_MOBILE_INFERENCE #ifndef PADDLE_MOBILE_INFERENCE
#include "MultiNetwork.h" #include "MultiNetwork.h"
#include "RecurrentGradientMachine.h" #include "RecurrentGradientMachine.h"
#include "paddle/gserver/layers/AgentLayer.h"
#endif #endif
namespace paddle { namespace paddle {
...@@ -192,9 +192,11 @@ void NeuralNetwork::init(const ModelConfig& config, ...@@ -192,9 +192,11 @@ void NeuralNetwork::init(const ModelConfig& config,
void NeuralNetwork::connect(LayerPtr agentLayer, void NeuralNetwork::connect(LayerPtr agentLayer,
LayerPtr realLayer, LayerPtr realLayer,
int height) { int height) {
#ifndef PADDLE_MOBILE_INFERENCE
AgentLayer* agent = dynamic_cast<AgentLayer*>(agentLayer.get()); AgentLayer* agent = dynamic_cast<AgentLayer*>(agentLayer.get());
CHECK_NOTNULL(agent); CHECK_NOTNULL(agent);
agent->setRealLayer(realLayer, height); agent->setRealLayer(realLayer, height);
#endif
} }
void NeuralNetwork::connect(std::string agentLayerName, void NeuralNetwork::connect(std::string agentLayerName,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册