# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

#-*- coding: utf-8 -*-

import paddle.fluid as fluid
import parl
from parl import layers


class Model(parl.Model):
    """Container that pairs an actor network with a critic network.

    The class only delegates: `policy` forwards to the actor and `value`
    forwards to the critic.
    """

    def __init__(self, act_dim):
        """Build the two sub-models.

        Args:
            act_dim: output size of the actor's final fc layer.
        """
        # Actor is created before the critic; keep this order so the layers
        # are instantiated in the same sequence as before.
        self.actor_model = ActorModel(act_dim)
        self.critic_model = CriticModel()

    def policy(self, obs):
        """Return the actor's output for `obs`."""
        actor = self.actor_model
        return actor.policy(obs)

    def value(self, obs, act):
        """Return the critic's output for the pair (`obs`, `act`)."""
        critic = self.critic_model
        return critic.value(obs, act)

    def get_actor_params(self):
        """Return whatever `parameters()` yields for the actor sub-model."""
        actor = self.actor_model
        return actor.parameters()


class ActorModel(parl.Model):
    """Two-layer fully connected network: relu hidden layer, tanh output."""

    def __init__(self, act_dim):
        """Create the layers.

        Args:
            act_dim: size of the tanh-activated output layer.
        """
        hidden_units = 100

        self.fc1 = layers.fc(size=hidden_units, act='relu')
        self.fc2 = layers.fc(size=act_dim, act='tanh')

    def policy(self, obs):
        """Feed `obs` through fc1 then fc2 and return the result.

        NOTE(review): the output is presumably the action means -- confirm
        against the algorithm that consumes it.
        """
        return self.fc2(self.fc1(obs))


class CriticModel(parl.Model):
    """Two-layer fully connected network producing one value per input row."""

    def __init__(self):
        """Create a relu hidden layer and a size-1 output layer (no activation)."""
        hidden_units = 100

        self.fc1 = layers.fc(size=hidden_units, act='relu')
        self.fc2 = layers.fc(size=1, act=None)

    def value(self, obs, act):
        """Score the pair (`obs`, `act`).

        `obs` and `act` are concatenated along axis 1, passed through fc1 and
        fc2, and axis 1 of the result is squeezed away before returning.
        """
        joint_input = layers.concat([obs, act], axis=1)
        q_value = self.fc2(self.fc1(joint_input))
        # fc2 has size 1, so drop that trailing axis.
        return layers.squeeze(q_value, axes=[1])