robopal.wrappers.single_env_wrapper 源代码

import numpy as np
from robopal.envs.bimanual_tasks.bimanual_manipulate import BimanualManipulate


class SingleEnvWrapper:
    """Expose a parallel multi-agent environment as a single-agent environment.

    Two modes are supported:

    * ``'combined'`` -- the per-agent observations are concatenated (in the
      order of ``pallel_env.agents``) into one observation, rewards are
      summed, and the episode is terminated/truncated only when every agent
      is.
    * ``'shared'`` -- only the observation, reward, flags and info of the
      first agent are reported.

    In both modes the single action vector is split into equally sized
    consecutive slices, one per agent.

    :param pallel_env: The parallel multi-agent environment to wrap.
    :param mode: The mode of the single-agent environment, possible values
        are 'combined' and 'shared'.
    :raises ValueError: If ``mode`` is neither 'combined' nor 'shared'.
    """

    def __init__(self, pallel_env: "BimanualManipulate", mode="combined") -> None:
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if mode not in ('combined', 'shared'):
            raise ValueError("The mode should be 'combined' or 'shared'")

        self._parallel_env = pallel_env
        self.mode = mode

        n_agents = len(self._parallel_env.agents)
        # The wrapper always consumes one flat action covering every agent.
        self.action_dim = (n_agents * self._parallel_env.action_dim[0],)
        if mode == 'combined':
            self.obs_dim = (n_agents * self._parallel_env.obs_dim[0],)
        else:
            self.obs_dim = self._parallel_env.obs_dim

    def step(self, action: np.ndarray):
        """Advance the wrapped environment by one step.

        :param action: Flat action; slice ``i`` of width
            ``pallel_env.action_dim[0]`` is sent to the ``i``-th agent.
        :return: ``(observation, reward, terminated, truncated, info)``
            aggregated according to ``self.mode``.
        :raises ValueError: If ``self.mode`` is not a supported mode.
        """
        env = self._parallel_env
        width = env.action_dim[0]
        # Convert the array-type action into the dict-type action of the env.
        actions = {
            agent: action[i * width:(i + 1) * width]
            for i, agent in enumerate(env.agents)
        }
        observations, rewards, terminations, truncations, infos = env.step(actions)

        # Read the agent list again after stepping, as the original code did.
        agents = env.agents
        if self.mode == 'combined':
            return (
                np.concatenate([observations[agent] for agent in agents]),
                np.sum(list(rewards.values())),
                all(terminations.values()),
                all(truncations.values()),
                {agent: infos[agent] for agent in agents},
            )
        if self.mode == 'shared':
            # Only return the information of the first agent.
            first = agents[0]
            return (
                observations[first],
                rewards[first],
                terminations[first],
                truncations[first],
                infos[first],
            )
        raise ValueError(f"Unsupported mode: {self.mode}")

    def reset(self, **kwargs):
        """Reset the wrapped environment.

        The returned observation follows the same layout as the one returned
        by :meth:`step`: concatenated over all agents in 'combined' mode,
        otherwise the first agent's observation.

        :param kwargs: Forwarded unchanged to the wrapped ``reset``.
        :return: ``(observation, info)``.
        """
        observations, infos = self._parallel_env.reset(**kwargs)
        agents = self._parallel_env.agents
        if self.mode == 'combined':
            # Must match ``obs_dim`` and the observation produced by ``step``.
            return (
                np.concatenate([observations[agent] for agent in agents]),
                {agent: infos[agent] for agent in agents},
            )
        first = agents[0]
        return observations[first], infos[first]

    def close(self):
        """Close the wrapped environment."""
        self._parallel_env.close()

    def render(self, mode='human'):
        """Render the wrapped environment and return whatever it returns."""
        return self._parallel_env.render(mode)

    def __getattr__(self, attr):
        """Forward lookups of unknown attributes to the wrapped environment.

        Python calls this only after normal attribute lookup has failed.
        """
        # Go through ``__dict__`` directly: using ``self._parallel_env`` here
        # would re-enter ``__getattr__`` and recurse forever whenever the
        # wrapped env has not been assigned yet (e.g. instance created with
        # ``__new__`` during copying or unpickling).
        try:
            env = self.__dict__['_parallel_env']
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {attr!r}"
            ) from None
        return getattr(env, attr)

    @property
    def unwrapped(self):
        """The wrapped parallel multi-agent environment."""
        return self._parallel_env