Gymnasium Style Interface

我们的环境遵循 Gymnasium（前身为 OpenAI Gym）接口规范，可以方便地搭建强化学习训练环境。

下面是一个简单的示例，演示了如何使用 Stable Baselines3（以及 sb3_contrib 提供的 TQC 算法）训练 Pick-and-Place 任务。其中，所用环境为 robopal/demos/single_task_manipulation 目录下的 PickAndPlaceEnv。

import os

from sb3_contrib import TQC
from stable_baselines3 import HerReplayBuffer
from stable_baselines3.common.callbacks import BaseCallback

from robopal.commons.gym_wrapper import GoalEnvWrapper
from robopal.demos.single_task_manipulation import PickAndPlaceEnv


class TensorboardCallback(BaseCallback):
    """
    Callback that periodically saves a checkpoint of the model being trained.

    Despite its name, this callback does not write anything to tensorboard:
    every ``save_freq`` calls of ``_on_step`` it saves ``self.model`` to
    ``<log_dir>/model_saved/TQC/diana_pick_place_v2_<n_calls>``.

    :param log_dir: Directory under which the checkpoints are written
        (``str`` or any ``os.PathLike``).
    :param verbose: Verbosity level forwarded to ``BaseCallback``.
    :param save_freq: Number of callback calls between two checkpoints.
        Must be a positive integer.
    :raises ValueError: If ``save_freq`` is not positive.
    """

    def __init__(self, log_dir, verbose=0, save_freq=51200):
        super().__init__(verbose)
        if save_freq <= 0:
            raise ValueError(f"save_freq must be positive, got {save_freq!r}")
        self.log_dir = log_dir
        self.save_freq = save_freq

    def _on_step(self) -> bool:
        """
        Save a checkpoint whenever ``n_calls`` is a multiple of ``save_freq``.

        :return: Always ``True``; this callback never requests that training stop.
        """
        if self.n_calls % self.save_freq == 0:
            # os.path.join avoids the doubled separator that plain concatenation
            # produced for a log_dir ending in "/", and keeps the path relative
            # when log_dir is empty instead of pointing at the filesystem root.
            checkpoint_path = os.path.join(
                self.log_dir,
                "model_saved",
                "TQC",
                f"diana_pick_place_v2_{self.n_calls}",
            )
            self.model.save(checkpoint_path)
        return True


# Root directory passed to the model as tensorboard_log and to the checkpoint callback.
log_dir = "log/"

# Build the Pick-and-Place task and wrap it with GoalEnvWrapper before
# handing it to the model.
env = GoalEnvWrapper(
    PickAndPlaceEnv(
        render_mode="human",
        control_freq=10,
        controller="JNTIMP",
    )
)

# Initialize the model: TQC with a HER replay buffer.
model = TQC(
    policy="MultiInputPolicy",
    env=env,
    replay_buffer_class=HerReplayBuffer,
    # Parameters for HER
    replay_buffer_kwargs={
        "n_sampled_goal": 4,
        "goal_selection_strategy": "future",
    },
    policy_kwargs={"n_critics": 2, "net_arch": [512, 512, 512]},
    batch_size=1024,
    gamma=0.95,
    tau=0.05,
    tensorboard_log=log_dir,
    verbose=1,
)

# Train the model for one million steps with the checkpoint callback attached.
model.learn(
    total_timesteps=1_000_000,
    callback=TensorboardCallback(log_dir=log_dir),
)

# Save the trained model.
model.save("./her_bit_env")

下·