未验证 提交 6fa8a94b 编写于 作者: D dzhwinter 提交者: GitHub

Merge pull request #824 from dzhwinter/inference

"add inference"
...@@ -36,9 +36,10 @@ class PolicyGradient: ...@@ -36,9 +36,10 @@ class PolicyGradient:
act="tanh" # tanh activation act="tanh" # tanh activation
) )
# fc2 # fc2
self.all_act_prob = fluid.layers.fc(input=fc1, all_act_prob = fluid.layers.fc(input=fc1,
size=self.n_actions, size=self.n_actions,
act="softmax") act="softmax")
self.inferece_program = fluid.defaul_main_program().clone()
# to maximize total reward (log_p * R) is to minimize -(log_p * R) # to maximize total reward (log_p * R) is to minimize -(log_p * R)
neg_log_prob = fluid.layers.cross_entropy( neg_log_prob = fluid.layers.cross_entropy(
input=self.all_act_prob, input=self.all_act_prob,
...@@ -53,7 +54,7 @@ class PolicyGradient: ...@@ -53,7 +54,7 @@ class PolicyGradient:
def choose_action(self, observation): def choose_action(self, observation):
prob_weights = self.exe.run( prob_weights = self.exe.run(
fluid.default_main_program().prune(self.all_act_prob), self.inferece_program,
feed={"obs": observation[np.newaxis, :]}, feed={"obs": observation[np.newaxis, :]},
fetch_list=[self.all_act_prob]) fetch_list=[self.all_act_prob])
prob_weights = np.array(prob_weights[0]) prob_weights = np.array(prob_weights[0])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册