Pwil

pwil_irl_model

PwilRewardModel

class ding.reward_model.pwil_irl_model.PwilRewardModel(config: Dict, device: str, tb_logger: SummaryWriter)[source]
Overview:

The Pwil (Primal Wasserstein Imitation Learning, https://arxiv.org/pdf/2006.04678.pdf) reward model class

Interface:

estimate, train, load_expert_data, collect_data, clear_data, __init__, _train, _get_state_distance, _get_action_distance

Properties:
  • reward_table (Dict): In this algorithm, the reward model is represented as a dictionary.

__init__(config: Dict, device: str, tb_logger: SummaryWriter) None[source]
Overview:

Initialize the Pwil reward model with the training config, device, and TensorBoard logger.

Arguments:
  • config (Dict): Training config

  • device (str): Device to run the model on, i.e. “cpu” or “cuda”

  • tb_logger (SummaryWriter): Logger for model summaries, by default a SummaryWriter
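
Example:

For orientation, a minimal instantiation sketch follows. Only config['expert_data_path'] is documented for this class (see load_expert_data below); the other config keys shown here are illustrative placeholders, not guaranteed defaults:

    from easydict import EasyDict
    from torch.utils.tensorboard import SummaryWriter
    from ding.reward_model.pwil_irl_model import PwilRewardModel

    # Hypothetical config: only 'expert_data_path' is documented for this class;
    # 'sample_size' is a placeholder for whatever other keys the default config defines.
    cfg = EasyDict(dict(
        expert_data_path='./expert_data.pkl',
        sample_size=1000,
    ))
    tb_logger = SummaryWriter('./log/pwil_reward_model')
    reward_model = PwilRewardModel(cfg, device='cpu', tb_logger=tb_logger)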

_get_action_distance(a1: list, a2: list) torch.Tensor[source]
Overview:

Get the pairwise distances between two lists of actions. A single action is a tensor of shape torch.Size([n]) (n is referred to in the in-code comments).

Arguments:
  • a1 (list of torch.Tensor): the first list of actions, of length M

  • a2 (list of torch.Tensor): the second list of actions, of length N

Returns:
  • distance (torch.Tensor): pairwise Euclidean distance matrix between the two action lists, of size M x N.

_get_state_distance(s1: list, s2: list) torch.Tensor[source]
Overview:

Get the pairwise distances between two lists of states. A single state is a tensor of shape torch.Size([n]) (n is referred to in the in-code comments).

Arguments:
  • s1 (list of torch.Tensor): the first list of states, of length M

  • s2 (list of torch.Tensor): the second list of states, of length N

Returns:
  • distance (torch.Tensor): pairwise Euclidean distance matrix between the two state lists, of size M x N.
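
Example:

Both distance helpers share the same contract: given a list of M tensors and a list of N tensors, each of shape torch.Size([n]), they return an M x N matrix of pairwise Euclidean distances. A minimal sketch of that computation (using torch.cdist for illustration, not necessarily the class's internal implementation):

    import torch

    def pairwise_euclidean(x_list, y_list):
        # Stack the lists of shape-[n] tensors into (M, n) and (N, n) matrices,
        # then compute all pairwise Euclidean distances in one call.
        x = torch.stack(x_list)          # (M, n)
        y = torch.stack(y_list)          # (N, n)
        return torch.cdist(x, y, p=2)    # (M, N)

    a1 = [torch.randn(3) for _ in range(4)]   # M = 4
    a2 = [torch.randn(3) for _ in range(5)]   # N = 5
    dist = pairwise_euclidean(a1, a2)
    assert dist.shape == (4, 5)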

clear_data() None[source]
Overview:

Clear the collected training data. This is a side-effect function that clears the data attribute of self.

collect_data(data: list) None[source]
Overview:

Collect training data, formatted by concat_state_action_pairs.

Arguments:
  • data (list): Raw training data (e.g. samples containing obs, action and other keys)

Effects:
  • This is a side-effect function that updates the data attribute of self. In this algorithm it also updates s_size and a_size (the state and action dimensions) in the self.cfg Dict, and recomputes reward_factor each time collect_data is called.
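
Example:

A sketch of what a call to collect_data might look like, continuing from the instantiation sketch above; the samples are dicts with obs and action tensors (shapes are illustrative):

    import torch

    # Hypothetical batch of transitions; only 'obs' and 'action' are needed
    # for the state-action concatenation.
    train_data = [
        {'obs': torch.randn(8), 'action': torch.randn(2)}
        for _ in range(64)
    ]
    reward_model.collect_data(train_data)
    # Side effects: the internal data buffer grows, cfg's s_size / a_size are
    # set from the sample shapes, and reward_factor is recomputed.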

estimate(data: list) None[source]
Overview:

Estimate reward by rewriting the reward key in each row of the data.

Arguments:
  • data (list): the list of data used for estimation, with at least obs and action keys.

Effects:
  • This is a side-effect function that updates the reward_table with the (obs, action) tuples from the input.
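
Example:

A sketch of estimate rewriting rewards in place (field names as documented above; values are illustrative):

    import torch

    batch = [
        {'obs': torch.randn(8), 'action': torch.randn(2), 'reward': torch.zeros(1)}
        for _ in range(32)
    ]
    reward_model.estimate(batch)
    # Each row's 'reward' entry now holds the PWIL reward for its (obs, action)
    # pair, and reward_table has been updated as a side effect.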

load_expert_data() None[source]
Overview:

Load the expert data from the path config['expert_data_path'] stored in self.

Effects:
  • This is a side-effect function that updates the expert data attribute (self.expert_data); in this algorithm it also updates self.expert_s and self.expert_a with the expert states and actions.
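
Example:

Conceptually the call amounts to the following (a sketch, using only the attributes named above):

    reward_model.load_expert_data()
    # Afterwards the expert trajectories are cached on the model:
    #   reward_model.expert_data  - raw expert data read from config['expert_data_path']
    #   reward_model.expert_s     - expert states
    #   reward_model.expert_a     - expert actions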

train() None[source]
Overview:

Train the Pwil reward model.
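
Example:

Putting the interface together, a typical usage cycle alternates data collection, reward-model training, and reward relabelling for the RL learner. A hedged sketch (collector and learner are placeholder objects, not part of this module):

    for it in range(100):
        # 1. Gather new transitions with the current policy (placeholder collector).
        new_data = collector.collect()
        reward_model.collect_data(new_data)

        # 2. Fit the Pwil reward model on the collected data.
        reward_model.train()

        # 3. Rewrite the rewards in place and hand the batch to the learner (placeholder).
        reward_model.estimate(new_data)
        learner.train(new_data)

        # 4. Optionally clear the reward model's buffer before the next iteration.
        reward_model.clear_data()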

concat_state_action_pairs

Overview:

Concatenate state and action pairs from the input iterator.

Arguments:
  • iterator (Iterable): An iterable of samples, each with at least obs and action tensor keys.

Returns:
  • res (torch.Tensor): Concatenated state-action pairs.
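
Example:

A minimal sketch of the concatenation described above, assuming each sample in the iterator is a dict of flat obs and action tensors (this illustrates the idea rather than the exact return type):

    import torch

    def concat_state_action_pairs_sketch(iterator):
        # Flatten obs and action of every sample, join them into one vector,
        # and stack the results into a (num_samples, s_size + a_size) tensor.
        pairs = [
            torch.cat([item['obs'].flatten(), item['action'].flatten().float()])
            for item in iterator
        ]
        return torch.stack(pairs)

    batch = [{'obs': torch.randn(8), 'action': torch.randn(2)} for _ in range(16)]
    res = concat_state_action_pairs_sketch(batch)
    assert res.shape == (16, 10)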