Unverified · Commit 6fa8a94b authored by dzhwinter, committed by GitHub

Merge pull request #824 from dzhwinter/inference

"add inference"
@@ -36,9 +36,10 @@ class PolicyGradient:
             act="tanh"  # tanh activation
         )
         # fc2
-        all_act_prob = fluid.layers.fc(input=fc1,
+        self.all_act_prob = fluid.layers.fc(input=fc1,
                                             size=self.n_actions,
                                             act="softmax")
+        self.inference_program = fluid.default_main_program().clone()
         # to maximize total reward (log_p * R) is to minimize -(log_p * R)
         neg_log_prob = fluid.layers.cross_entropy(
             input=self.all_act_prob,
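The point of the added line is ordering: `Program.clone()` snapshots every op appended to the default main program so far, so cloning right after the softmax layer, and before the cross-entropy loss and optimizer ops are added, yields a program that contains only the forward pass. A minimal standalone sketch of that pattern (the loss is simplified, and names such as `n_features` and `fake_obs` are illustrative, not from this diff):

    import numpy as np
    import paddle.fluid as fluid

    n_features, n_actions = 4, 2

    obs = fluid.layers.data(name='obs', shape=[n_features], dtype='float32')
    acts = fluid.layers.data(name='acts', shape=[1], dtype='int64')

    fc1 = fluid.layers.fc(input=obs, size=10, act="tanh")
    all_act_prob = fluid.layers.fc(input=fc1, size=n_actions, act="softmax")

    # Clone BEFORE the loss/optimizer ops are appended, so the clone
    # holds only the forward (inference) graph.
    inference_program = fluid.default_main_program().clone()

    # Training ops land only in the default main program, not the clone.
    neg_log_prob = fluid.layers.cross_entropy(input=all_act_prob, label=acts)
    loss = fluid.layers.reduce_mean(neg_log_prob)
    fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())

    # Inference runs the cloned program and only needs `obs` fed.
    fake_obs = np.random.random((1, n_features)).astype('float32')
    probs = exe.run(inference_program,
                    feed={'obs': fake_obs},
                    fetch_list=[all_act_prob])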
@@ -53,7 +54,7 @@ class PolicyGradient:
 
     def choose_action(self, observation):
         prob_weights = self.exe.run(
-            fluid.default_main_program().prune(self.all_act_prob),
+            self.inference_program,
             feed={"obs": observation[np.newaxis, :]},
             fetch_list=[self.all_act_prob])
         prob_weights = np.array(prob_weights[0])
......
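The hunk ends mid-method. A typical continuation of this policy-gradient pattern (hypothetical here, since the rest of `choose_action` is truncated) samples an action index from the fetched distribution:

    # Hypothetical continuation: after np.array(...) above, `prob_weights`
    # has shape (1, n_actions); sample one action by predicted probability.
    action = np.random.choice(
        range(prob_weights.shape[1]),  # candidate action indices
        p=prob_weights.ravel())        # flatten to a 1-D probability vector
    return action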