PokéLLMon 源码解析(四)

2024-03-08 09:16:00 浏览数 (2)

.\PokeLLMon\poke_env\exceptions.py

代码语言:python
"""
This module contains exceptions.
"""

# 定义一个自定义异常类 ShowdownException,继承自内置异常类 Exception
class ShowdownException(Exception):
    """Raised when the server sends a message that is not managed."""

.\PokeLLMon\poke_env\player\baselines.py

代码语言:python
# 导入必要的模块
from typing import List
import json
import os

from poke_env.environment.abstract_battle import AbstractBattle
from poke_env.environment.double_battle import DoubleBattle
from poke_env.environment.move_category import MoveCategory
from poke_env.environment.pokemon import Pokemon
from poke_env.environment.side_condition import SideCondition
from poke_env.player.player import Player
from poke_env.data.gen_data import GenData

# Load the move-effect table once, at import time. The path is relative to
# the current working directory. JSON text is UTF-8, so the encoding is
# pinned instead of being left to the platform default.
with open("./poke_env/data/static/moves/moves_effect.json", "r", encoding="utf-8") as f:
    move_effect = json.load(f)

# 计算招式类型的伤害倍率
def calculate_move_type_damage_multipier(type_1, type_2, type_chart, constraint_type_list):
    """Bucket attacking types by the damage multiplier they get against a defender.

    Args:
        type_1: Name of the defender's first type; used as a key of ``type_chart``.
        type_2: Name of the defender's second type, or a falsy value if it has none.
        type_chart: Nested mapping read as ``type_chart[defender_type][attacking_type]``
            and yielding a numeric multiplier.
        constraint_type_list: Optional collection of attacking-type names. When
            non-empty, every returned bucket is restricted to these names.

    Returns:
        A tuple of five lists ``(extreme, effective, resistant,
        extreme_resistant, immune)`` containing the attacking types whose
        combined multiplier is 4, 2, 1/2, 1/4 and 0 respectively. Types with
        any other multiplier (normally 1) are omitted. Each list keeps the
        iteration order of the multiplier table, so the output is
        deterministic.
    """
    # Every type name used when combining two defender types.
    TYPE_list = 'BUG,DARK,DRAGON,ELECTRIC,FAIRY,FIGHTING,FIRE,FLYING,GHOST,GRASS,GROUND,ICE,NORMAL,POISON,PSYCHIC,ROCK,STEEL,WATER'.split(",")

    if type_2:
        # Dual-typed defender: the two per-type multipliers are multiplied.
        move_type_damage_multiplier_dict = {
            type_name: type_chart[type_1][type_name] * type_chart[type_2][type_name]
            for type_name in TYPE_list
        }
    else:
        # Single-typed defender: its chart row already is the multiplier table.
        move_type_damage_multiplier_dict = type_chart[type_1]

    effective_type_list = []
    extreme_type_list = []
    resistant_type_list = []
    extreme_resistant_type_list = []
    immune_type_list = []
    for type_name, value in move_type_damage_multiplier_dict.items():
        if value == 2:
            effective_type_list.append(type_name)
        elif value == 4:
            extreme_type_list.append(type_name)
        elif value == 1 / 2:
            resistant_type_list.append(type_name)
        elif value == 1 / 4:
            extreme_resistant_type_list.append(type_name)
        elif value == 0:
            immune_type_list.append(type_name)
        # Any other value (neutral damage) is not reported.

    if constraint_type_list:
        # Filter in place of a set intersection so the original order is
        # kept; list(set(...)) would reorder the names between runs.
        allowed = set(constraint_type_list)
        extreme_type_list = [t for t in extreme_type_list if t in allowed]
        effective_type_list = [t for t in effective_type_list if t in allowed]
        resistant_type_list = [t for t in resistant_type_list if t in allowed]
        extreme_resistant_type_list = [t for t in extreme_resistant_type_list if t in allowed]
        immune_type_list = [t for t in immune_type_list if t in allowed]

    return extreme_type_list, effective_type_list, resistant_type_list, extreme_resistant_type_list, immune_type_list
# 定义一个函数,根据给定的参数计算并返回对应的移动类型伤害提示
def move_type_damage_wraper(pokemon_name, type_1, type_2, type_chart, constraint_type_list=None):
    """Describe in one sentence which move types are strong or weak against a Pokemon.

    Args:
        pokemon_name: Name placed at the start of the sentence.
        type_1: Name of the Pokemon's first type.
        type_2: Name of the Pokemon's second type, or a falsy value.
        type_chart: Type chart passed through to
            ``calculate_move_type_damage_multipier``.
        constraint_type_list: Optional collection of move types to restrict
            the description to.

    Returns:
        The description string, or ``""`` when no move type falls in any of
        the five non-neutral buckets.
    """
    (extreme_effective_type_list, effective_type_list, resistant_type_list,
     extreme_resistant_type_list, immune_type_list) = calculate_move_type_damage_multipier(
        type_1, type_2, type_chart, constraint_type_list)

    # The guard previously ignored the two "extreme" buckets, so a Pokemon
    # whose only entries were 4x or 0.25x produced an empty prompt even
    # though the clauses below can render them.
    if not (extreme_effective_type_list or effective_type_list or resistant_type_list
            or extreme_resistant_type_list or immune_type_list):
        return ""

    move_type_damage_prompt = f"{pokemon_name}"
    if extreme_effective_type_list:
        move_type_damage_prompt = move_type_damage_prompt + " can be super-effectively attacked by " + ", ".join(
            extreme_effective_type_list) + " moves"
    if effective_type_list:
        move_type_damage_prompt = move_type_damage_prompt + ", can be effectively attacked by " + ", ".join(
            effective_type_list) + " moves"
    if resistant_type_list:
        move_type_damage_prompt = move_type_damage_prompt + ", is resistant to " + ", ".join(
            resistant_type_list) + " moves"
    if extreme_resistant_type_list:
        move_type_damage_prompt = move_type_damage_prompt + ", is super-resistant to " + ", ".join(
            extreme_resistant_type_list) + " moves"
    if immune_type_list:
        move_type_damage_prompt = move_type_damage_prompt + ", is immuned to " + ", ".join(
            immune_type_list) + " moves"

    return move_type_damage_prompt


# 定义一个类,继承自Player类,实现最大基础伤害玩家
class MaxBasePowerPlayer(Player):
    """Player that always uses the available move with the highest base power."""

    def choose_move(self, battle: AbstractBattle):
        """Return an order for the strongest available move.

        Falls back to ``choose_random_move`` when the battle offers no
        available moves.
        """
        if not battle.available_moves:
            return self.choose_random_move(battle)

        def power_of(candidate):
            return candidate.base_power

        strongest = max(battle.available_moves, key=power_of)
        return self.create_order(strongest)

# 定义一个类,继承自Player类,实现简单启发式玩家
class SimpleHeuristicsPlayer(Player):
    """Rule-based player helpers.

    The visible part of the class scores matchups, decides when to dynamax
    or switch out, estimates boosted stats, computes a reward, and renders
    boost levels and statuses as numbers and words.
    """

    # Entry-hazard name -> side condition.
    # NOTE(review): the keys look like Showdown move ids; "stealhrock" was a
    # typo that could never equal "stealthrock" — confirm at the lookup site.
    ENTRY_HAZARDS = {
        "spikes": SideCondition.SPIKES,
        "stealthrock": SideCondition.STEALTH_ROCK,
        "stickyweb": SideCondition.STICKY_WEB,
        "toxicspikes": SideCondition.TOXIC_SPIKES,
    }

    # Names of hazard-removal moves.
    ANTI_HAZARDS_MOVES = {"rapidspin", "defog"}

    # Amount added to / subtracted from the matchup score for being faster / slower.
    SPEED_TIER_COEFICIENT = 0.1
    # Weight applied to each side's HP fraction in the matchup score.
    HP_FRACTION_COEFICIENT = 0.4
    # Matchup score below which switching out is recommended.
    SWITCH_OUT_MATCHUP_THRESHOLD = -2

    # Boost level -> multiplier used when the boosted state is "accuracy".
    _ACCURACY_BOOST_MULTIPLIERS = {
        0: 1.0, 1: 1.33, 2: 1.66, 3: 2.0, 4: 2.5, 5: 2.66, 6: 3.0,
        -1: 0.75, -2: 0.6, -3: 0.5, -4: 0.43, -5: 0.36, -6: 0.33,
    }
    # Boost level -> multiplier used for every other state.
    _STAT_BOOST_MULTIPLIERS = {
        0: 1.0, 1: 1.5, 2: 2.0, 3: 2.5, 4: 3.0, 5: 3.5, 6: 4.0,
        -1: 0.67, -2: 0.5, -3: 0.4, -4: 0.33, -5: 0.29, -6: 0.25,
    }
    # ``status.value`` -> word returned by ``check_status``.
    _STATUS_LABELS = {
        1: "burnt", 2: "fainted", 3: "frozen", 4: "paralyzed",
        5: "poisoned", 6: "sleeping", 7: "toxic",
    }

    def _estimate_matchup(self, mon: Pokemon, opponent: Pokemon):
        """Score ``mon`` against ``opponent``; callers treat higher as better for ``mon``.

        The score is the largest ``opponent.damage_multiplier(t)`` over
        ``mon``'s types minus the largest ``mon.damage_multiplier(t)`` over
        ``opponent``'s types, adjusted by a speed bonus/penalty and by the
        difference of the two HP fractions.
        """
        score = max([opponent.damage_multiplier(t) for t in mon.types if t is not None])
        score -= max(
            [mon.damage_multiplier(t) for t in opponent.types if t is not None]
        )
        # Speed tier: a fixed bonus when faster, the same penalty when slower.
        if mon.base_stats["spe"] > opponent.base_stats["spe"]:
            score += self.SPEED_TIER_COEFICIENT
        elif opponent.base_stats["spe"] > mon.base_stats["spe"]:
            score -= self.SPEED_TIER_COEFICIENT

        # Remaining HP: accumulate onto the score rather than replacing it.
        score += mon.current_hp_fraction * self.HP_FRACTION_COEFICIENT
        score -= opponent.current_hp_fraction * self.HP_FRACTION_COEFICIENT

        return score

    def _should_dynamax(self, battle: AbstractBattle, n_remaining_mons: int):
        """Return True when dynamax is available, enabled, and one of three rules holds."""
        if battle.can_dynamax and self._dynamax_disable is False:
            # Rule 1: the active Pokemon is the only full-HP team member.
            if (
                len([m for m in battle.team.values() if m.current_hp_fraction == 1])
                == 1
                and battle.active_pokemon.current_hp_fraction == 1
            ):
                return True
            # Rule 2: favourable matchup while both actives are at full HP.
            if (
                self._estimate_matchup(
                    battle.active_pokemon, battle.opponent_active_pokemon
                )
                > 0
                and battle.active_pokemon.current_hp_fraction == 1
                and battle.opponent_active_pokemon.current_hp_fraction == 1
            ):
                return True
            # Rule 3: only one Pokemon is left.
            if n_remaining_mons == 1:
                return True
        return False

    def _should_switch_out(self, battle: AbstractBattle):
        """Return True when a good switch exists and the active Pokemon has a reason to leave."""
        active = battle.active_pokemon
        opponent = battle.opponent_active_pokemon
        # A switch is only considered if some candidate has a positive matchup.
        if [
            m
            for m in battle.available_switches
            if self._estimate_matchup(m, opponent) > 0
        ]:
            # Heavily lowered defensive stats.
            if active.boosts["def"] <= -3 or active.boosts["spd"] <= -3:
                return True
            # Heavily lowered attack on a Pokemon whose attack is at least its special attack.
            if (
                active.boosts["atk"] <= -3
                and active.stats["atk"] >= active.stats["spa"]
            ):
                return True
            # Heavily lowered special attack on a Pokemon whose special attack is at least its attack.
            if (
                active.boosts["spa"] <= -3
                and active.stats["atk"] <= active.stats["spa"]
            ):
                return True
            # Matchup score below the switch-out threshold.
            if (
                self._estimate_matchup(active, opponent)
                < self.SWITCH_OUT_MATCHUP_THRESHOLD
            ):
                return True
        return False

    def _stat_estimation(self, mon: Pokemon, stat: str):
        """Estimate ``stat`` of ``mon`` from its base stat and current boost stage."""
        # Positive stages scale by (2 + n) / 2, the rest by 2 / (2 - n).
        # The threshold is 0, not 1: with "> 1" a +1 stage took the second
        # formula and produced 2.0 instead of 1.5.
        if mon.boosts[stat] > 0:
            boost = (2 + mon.boosts[stat]) / 2
        else:
            boost = 2 / (2 - mon.boosts[stat])
        return ((2 * mon.base_stats[stat] + 31) + 5) * boost

    def calc_reward(
            self, current_battle: AbstractBattle
    ) -> float:
        """Return the reward for ``current_battle`` via ``reward_computing_helper``."""
        return self.reward_computing_helper(
            current_battle, fainted_value=2.0, hp_value=1.0, victory_value=30.0
        )

    def boost_multiplier(self, state, level):
        """Return the multiplier for boost ``level`` of ``state``.

        ``"accuracy"`` uses its own table; every other state shares the
        second one. Levels outside -6..6 return ``None``.
        """
        if state == "accuracy":
            return self._ACCURACY_BOOST_MULTIPLIERS.get(level)
        return self._STAT_BOOST_MULTIPLIERS.get(level)

    def check_status(self, status):
        """Return a word for ``status``.

        A falsy status gives ``"healthy"``; a status whose ``value`` is not
        in 1..7 gives ``None``.
        """
        if not status:
            return "healthy"
        return self._STATUS_LABELS.get(status.value)

.\PokeLLMon\poke_env\player\battle_order.py

代码语言:python
# 从 dataclasses 模块中导入 dataclass 装饰器
# 从 typing 模块中导入 Any, List, Optional, Union 类型
# 从 poke_env.environment.double_battle 模块中导入 DoubleBattle 类
# 从 poke_env.environment.move 模块中导入 Move 类
# 从 poke_env.environment.pokemon 模块中导入 Pokemon 类
from dataclasses import dataclass
from typing import Any, List, Optional, Union

from poke_env.environment.double_battle import DoubleBattle
from poke_env.environment.move import Move
from poke_env.environment.pokemon import Pokemon

# 定义一个名为 BattleOrder 的数据类
@dataclass
class BattleOrder:
    """A single battle decision and its rendering as a ``/choose`` command."""

    # The move to use or the Pokemon to switch to; ``None`` renders as "".
    order: Optional[Union[Move, Pokemon]]
    # Flags for a move order. ``message`` appends at most one of them,
    # checked in this order: mega, z_move, dynamax, terastallize.
    mega: bool = False
    z_move: bool = False
    dynamax: bool = False
    terastallize: bool = False
    # Target slot; omitted from the command while it equals the empty-target value.
    move_target: int = DoubleBattle.EMPTY_TARGET_POSITION

    # Command returned by subclasses when no specific order is available.
    DEFAULT_ORDER = "/choose default"

    def __str__(self) -> str:
        """Return the same text as ``message``."""
        return self.message

    @property
    def message(self) -> str:
        """Return the command for this order.

        A ``Move`` gives ``/choose move <id>`` plus an optional modifier and
        target, a ``Pokemon`` gives ``/choose switch <species>``, and
        anything else gives an empty string.
        """
        if isinstance(self.order, Move):
            # "recharge" is sent as move slot 1 rather than by id.
            if self.order.id == "recharge":
                return "/choose move 1"

            message = f"/choose move {self.order.id}"
            # Each suffix extends the command; it must not replace it.
            if self.mega:
                message += " mega"
            elif self.z_move:
                message += " zmove"
            elif self.dynamax:
                message += " dynamax"
            elif self.terastallize:
                message += " terastallize"

            if self.move_target != DoubleBattle.EMPTY_TARGET_POSITION:
                message += f" {self.move_target}"
            return message
        elif isinstance(self.order, Pokemon):
            return f"/choose switch {self.order.species}"
        else:
            return ""

# 定义一个名为 DefaultBattleOrder 的类,继承自 BattleOrder 类
class DefaultBattleOrder(BattleOrder):
    """Order whose message is always the default command."""

    def __init__(self, *args: Any, **kwargs: Any):
        """Accept and discard any arguments; no fields are initialised."""

    @property
    def message(self) -> str:
        """Return ``DEFAULT_ORDER``."""
        return self.DEFAULT_ORDER

# 定义一个名为 DoubleBattleOrder 的数据类,继承自 BattleOrder 类
@dataclass
class DoubleBattleOrder(BattleOrder):
    """Two orders combined into a single command.

    NOTE(review): this ``__init__`` never sets the inherited dataclass
    fields (``order`` and friends), so anything that reads them on an
    instance, such as the generated ``__repr__``, may fail — confirm.
    """

    def __init__(
        self,
        first_order: Optional[BattleOrder] = None,
        second_order: Optional[BattleOrder] = None,
    ):
        """Store the two optional orders."""
        self.first_order = first_order
        self.second_order = second_order

    @property
    def message(self) -> str:
        """Return the combined command.

        With both orders present, the second one's ``/choose `` prefix is
        removed and the two are joined with ``", "``. With only one order,
        ``", default"`` is appended to it. With none, ``DEFAULT_ORDER`` is
        returned.
        """
        if self.first_order and self.second_order:
            return (
                self.first_order.message
                + ", "
                + self.second_order.message.replace("/choose ", "")
            )
        elif self.first_order:
            return self.first_order.message + ", default"
        elif self.second_order:
            return self.second_order.message + ", default"
        else:
            return self.DEFAULT_ORDER

    @staticmethod
    def join_orders(first_orders: List[BattleOrder], second_orders: List[BattleOrder]):
        """Build the double orders obtainable from two lists of single orders.

        When both lists are non-empty, every pair is combined except pairs
        where both orders set the same flag (mega, z_move, dynamax or
        terastallize) or where both have an equal ``order``. When only one
        list is non-empty, each of its orders is wrapped as a
        ``first_order``. If nothing can be produced, a single
        ``DefaultBattleOrder`` is returned.
        """
        if first_orders and second_orders:
            orders = [
                DoubleBattleOrder(first_order=first_order, second_order=second_order)
                for first_order in first_orders
                for second_order in second_orders
                if not first_order.mega or not second_order.mega
                if not first_order.z_move or not second_order.z_move
                if not first_order.dynamax or not second_order.dynamax
                if not first_order.terastallize or not second_order.terastallize
                if first_order.order != second_order.order
            ]
            # If every pair was filtered out, fall through to the default below.
            if orders:
                return orders
        elif first_orders:
            return [DoubleBattleOrder(first_order=order) for order in first_orders]
        elif second_orders:
            # Orders from the second list are also stored as ``first_order``.
            return [DoubleBattleOrder(first_order=order) for order in second_orders]
        return [DefaultBattleOrder()]
# 定义一个名为ForfeitBattleOrder的类,继承自BattleOrder类
class ForfeitBattleOrder(BattleOrder):
    """Order whose message is the forfeit command."""

    def __init__(self, *args: Any, **kwargs: Any):
        """Accept and discard any arguments; no fields are initialised."""

    @property
    def message(self) -> str:
        """Return ``"/forfeit"``."""
        return "/forfeit"

.\PokeLLMon\poke_env\player\gpt_player.py

代码语言:python
import json  # 导入 json 模块
import os  # 导入 os 模块
import random  # 导入 random 模块
from typing import List  # 导入 List 类型提示
from poke_env.environment.abstract_battle import AbstractBattle  # 导入 AbstractBattle 类
from poke_env.environment.double_battle import DoubleBattle  # 导入 DoubleBattle 类
from poke_env.environment.move_category import MoveCategory  # 导入 MoveCategory 类
from poke_env.environment.pokemon import Pokemon  # 导入 Pokemon 类
from poke_env.environment.side_condition import SideCondition  # 导入 SideCondition 类
from poke_env.player.player import Player, BattleOrder  # 导入 Player 和 BattleOrder 类
from typing import Dict, List, Optional, Union  # 导入 Dict, List, Optional, Union 类型提示
from poke_env.environment.move import Move  # 导入 Move 类
import time  # 导入 time 模块
import json  # 再次导入 json 模块(重复导入)
from openai import OpenAI  # 导入 OpenAI 类
from poke_env.data.gen_data import GenData  # 导入 GenData 类

def calculate_move_type_damage_multipier(type_1, type_2, type_chart, constraint_type_list):
    """Bucket attacking types by the damage multiplier they get against a defender.

    Args:
        type_1: Name of the defender's first type; used as a key of ``type_chart``.
        type_2: Name of the defender's second type, or a falsy value if it has none.
        type_chart: Nested mapping read as ``type_chart[defender_type][attacking_type]``
            and yielding a numeric multiplier.
        constraint_type_list: Optional collection of attacking-type names (in
            the chart's own spelling). When non-empty, every returned bucket
            is restricted to these names.

    Returns:
        A tuple of five lists ``(extreme, effective, resistant,
        extreme_resistant, immune)`` containing the attacking types whose
        combined multiplier is 4, 2, 1/2, 1/4 and 0 respectively, each name
        passed through ``str.capitalize``. Each list keeps the iteration
        order of the multiplier table, so the output is deterministic.
    """
    TYPE_list = 'BUG,DARK,DRAGON,ELECTRIC,FAIRY,FIGHTING,FIRE,FLYING,GHOST,GRASS,GROUND,ICE,NORMAL,POISON,PSYCHIC,ROCK,STEEL,WATER'.split(",")

    if type_2:
        # Dual-typed defender: the two per-type multipliers are multiplied.
        move_type_damage_multiplier_dict = {
            type_name: type_chart[type_1][type_name] * type_chart[type_2][type_name]
            for type_name in TYPE_list
        }
    else:
        # Single-typed defender: its chart row already is the multiplier table.
        move_type_damage_multiplier_dict = type_chart[type_1]

    effective_type_list = []
    extreme_type_list = []
    resistant_type_list = []
    extreme_resistant_type_list = []
    immune_type_list = []
    for type_name, value in move_type_damage_multiplier_dict.items():
        if value == 2:
            effective_type_list.append(type_name)
        elif value == 4:
            extreme_type_list.append(type_name)
        elif value == 1 / 2:
            resistant_type_list.append(type_name)
        elif value == 1 / 4:
            extreme_resistant_type_list.append(type_name)
        elif value == 0:
            immune_type_list.append(type_name)
        # Any other value (neutral damage) is not reported.

    if constraint_type_list:
        # Filter in place of a set intersection so the original order is
        # kept; list(set(...)) would reorder the names between runs.
        allowed = set(constraint_type_list)
        extreme_type_list = [t for t in extreme_type_list if t in allowed]
        effective_type_list = [t for t in effective_type_list if t in allowed]
        resistant_type_list = [t for t in resistant_type_list if t in allowed]
        extreme_resistant_type_list = [t for t in extreme_resistant_type_list if t in allowed]
        immune_type_list = [t for t in immune_type_list if t in allowed]

    return ([t.capitalize() for t in extreme_type_list],
            [t.capitalize() for t in effective_type_list],
            [t.capitalize() for t in resistant_type_list],
            [t.capitalize() for t in extreme_resistant_type_list],
            [t.capitalize() for t in immune_type_list])
# 定义一个函数,用于计算给定精灵对应的移动类型伤害提示
def move_type_damage_wraper(pokemon, type_chart, constraint_type_list=None):
    """Describe which attacking types are strong or weak against ``pokemon``.

    Args:
        pokemon: Object exposing ``type_1``, ``type_2`` (each with ``.name``
            when set) and ``species``.
        type_chart: Type chart passed through to
            ``calculate_move_type_damage_multipier``.
        constraint_type_list: Optional collection of attacking types to
            restrict the description to.

    Returns:
        One sentence per non-empty bucket, each preceded by a space, or
        ``""`` when nothing applies or the Pokemon has no first type.
    """
    # Without a first type there is no chart row to look up; the lookup
    # would raise KeyError, so report nothing instead.
    if not pokemon.type_1:
        return ""

    type_1 = pokemon.type_1.name
    type_2 = pokemon.type_2.name if pokemon.type_2 else None

    (extreme_effective_type_list, effective_type_list, resistant_type_list,
     extreme_resistant_type_list, immune_type_list) = calculate_move_type_damage_multipier(
        type_1, type_2, type_chart, constraint_type_list)

    move_type_damage_prompt = ""

    if extreme_effective_type_list:
        move_type_damage_prompt = (move_type_damage_prompt + " " + ", ".join(extreme_effective_type_list) +
                                   f"-type attack is extremely-effective (4x damage) to {pokemon.species}.")

    if effective_type_list:
        move_type_damage_prompt = (move_type_damage_prompt + " " + ", ".join(effective_type_list) +
                                   f"-type attack is super-effective (2x damage) to {pokemon.species}.")

    if resistant_type_list:
        move_type_damage_prompt = (move_type_damage_prompt + " " + ", ".join(resistant_type_list) +
                                   f"-type attack is ineffective (0.5x damage) to {pokemon.species}.")

    if extreme_resistant_type_list:
        move_type_damage_prompt = (move_type_damage_prompt + " " + ", ".join(extreme_resistant_type_list) +
                                   f"-type attack is highly ineffective (0.25x damage) to {pokemon.species}.")

    if immune_type_list:
        move_type_damage_prompt = (move_type_damage_prompt + " " + ", ".join(immune_type_list) +
                                   f"-type attack is zero effect (0x damage) to {pokemon.species}.")

    return move_type_damage_prompt


# 定义一个类,继承自Player类
class LLMPlayer(Player):
    """Player that queries an OpenAI chat model and parses its reply into an order.

    NOTE(review): ``api_key``, ``completion_tokens``, ``prompt_tokens``,
    ``_reward_buffer``, ``_dynamax_disable``, ``SPEED_TIER_COEFICIENT`` and
    ``HP_FRACTION_COEFICIENT`` are read here but never set in the visible
    code — presumably initialised in ``__init__`` or the ``Player`` base
    class; confirm.
    """

    # Boost level -> multiplier used when the boosted state is "accuracy".
    _ACCURACY_BOOST_MULTIPLIERS = {
        0: 1.0, 1: 1.33, 2: 1.66, 3: 2.0, 4: 2.5, 5: 2.66, 6: 3.0,
        -1: 0.75, -2: 0.6, -3: 0.5, -4: 0.43, -5: 0.36, -6: 0.33,
    }
    # Boost level -> multiplier used for every other state.
    _STAT_BOOST_MULTIPLIERS = {
        0: 1.0, 1: 1.5, 2: 2.0, 3: 2.5, 4: 3.0, 5: 3.5, 6: 4.0,
        -1: 0.67, -2: 0.5, -3: 0.4, -4: 0.33, -5: 0.29, -6: 0.25,
    }
    # ``status.value`` -> word returned by ``check_status``.
    _STATUS_LABELS = {
        1: "burnt", 2: "fainted", 3: "frozen", 4: "paralyzed",
        5: "poisoned", 6: "sleeping", 7: "toxic",
    }

    def chatgpt(self, system_prompt, user_prompt, model, temperature=0.7, json_format=False, seed=None, stop=None, max_tokens=200) -> str:
        """Send one system+user exchange to the chat API and return the reply text.

        Args:
            system_prompt: Content of the system message.
            user_prompt: Content of the user message.
            model: Model name passed to the API.
            temperature: Sampling temperature.
            json_format: When true, request a JSON-object response format.
            seed: Accepted for compatibility; currently not sent to the API.
            stop: Stop sequences; ``None`` is sent as an empty list.
            max_tokens: Completion token limit.

        Returns:
            The content of the first returned choice.

        Side effects:
            Adds the reported token usage to ``self.completion_tokens`` and
            ``self.prompt_tokens``.
        """
        client = OpenAI(api_key=self.api_key)
        request = {
            "model": model,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            "temperature": temperature,
            "stream": False,
            # "seed": seed,
            # A fresh list per call replaces the former shared mutable default.
            "stop": [] if stop is None else stop,
            "max_tokens": max_tokens,
        }
        if json_format:
            request["response_format"] = {"type": "json_object"}
        response = client.chat.completions.create(**request)

        outputs = response.choices[0].message.content
        # Token counters accumulate across calls.
        self.completion_tokens += response.usage.completion_tokens
        self.prompt_tokens += response.usage.prompt_tokens

        return outputs

    def _estimate_matchup(self, mon: Pokemon, opponent: Pokemon):
        """Score ``mon`` against ``opponent``; ``_should_dynamax`` treats a positive score as an advantage.

        The score is the largest ``opponent.damage_multiplier(t)`` over
        ``mon``'s types minus the largest ``mon.damage_multiplier(t)`` over
        ``opponent``'s types, adjusted by a speed bonus/penalty and by the
        difference of the two HP fractions.
        """
        score = max([opponent.damage_multiplier(t) for t in mon.types if t is not None])
        score -= max(
            [mon.damage_multiplier(t) for t in opponent.types if t is not None]
        )
        if mon.base_stats["spe"] > opponent.base_stats["spe"]:
            score += self.SPEED_TIER_COEFICIENT
        elif opponent.base_stats["spe"] > mon.base_stats["spe"]:
            score -= self.SPEED_TIER_COEFICIENT

        score += mon.current_hp_fraction * self.HP_FRACTION_COEFICIENT
        score -= opponent.current_hp_fraction * self.HP_FRACTION_COEFICIENT

        return score

    def _should_dynamax(self, battle: AbstractBattle):
        """Return True when dynamax is available, enabled, and one of three rules holds."""
        n_remaining_mons = len(
            [m for m in battle.team.values() if m.fainted is False]
        )
        if battle.can_dynamax and self._dynamax_disable is False:
            # Rule 1: the active Pokemon is the only full-HP team member.
            if (
                len([m for m in battle.team.values() if m.current_hp_fraction == 1])
                == 1
                and battle.active_pokemon.current_hp_fraction == 1
            ):
                return True
            # Rule 2: favourable matchup while both actives are at full HP.
            if (
                self._estimate_matchup(
                    battle.active_pokemon, battle.opponent_active_pokemon
                )
                > 0
                and battle.active_pokemon.current_hp_fraction == 1
                and battle.opponent_active_pokemon.current_hp_fraction == 1
            ):
                return True
            # Rule 3: only one Pokemon has not fainted.
            if n_remaining_mons == 1:
                return True
        return False

    @staticmethod
    def _load_action_json(llm_output):
        """Parse the text from the first ``{`` to the last ``}`` of ``llm_output`` as JSON."""
        json_start = llm_output.find('{')
        json_end = llm_output.rfind('}') + 1  # one past the last closing brace
        return json.loads(llm_output[json_start:json_end])

    # NOTE(review): the ``def`` line of this method was missing, leaving its
    # body (including ``return``) at class level. The name ``parse`` and the
    # parameters are reconstructed from the names the body uses — confirm.
    def parse(self, llm_output, battle):
        """Turn a reply shaped like ``{"move": ...}`` or ``{"switch": ...}`` into an order.

        Raises:
            ValueError: When no available move or switch matches the reply
                (``json.loads`` failures are ``ValueError`` subclasses too).
        """
        llm_action_json = self._load_action_json(llm_output)
        next_action = None

        if "move" in llm_action_json.keys():
            # Drop spaces and hyphens before comparing with move ids.
            llm_move_id = llm_action_json["move"]
            llm_move_id = llm_move_id.replace(" ", "").replace("-", "")
            for move in battle.available_moves:
                if move.id.lower() == llm_move_id.lower():
                    next_action = self.create_order(move, dynamax=self._should_dynamax(battle))

        elif "switch" in llm_action_json.keys():
            llm_switch_species = llm_action_json["switch"]
            for pokemon in battle.available_switches:
                if pokemon.species.lower() == llm_switch_species.lower():
                    next_action = self.create_order(pokemon)

        if next_action is None:
            raise ValueError("Value Error")
        return next_action

    # NOTE(review): as with ``parse``, the ``def`` line was missing; the name
    # ``parse_new`` is reconstructed — confirm.
    def parse_new(self, llm_output, battle):
        """Turn a reply shaped like ``{"decision": {"action": ..., "target": ...}}`` into an order.

        Raises:
            ValueError: When the action/target pair matches no available
                move or switch.
            KeyError: When ``decision``, ``action`` or ``target`` is absent.
        """
        llm_action_json = self._load_action_json(llm_output)
        next_action = None
        action = llm_action_json["decision"]["action"]
        target = llm_action_json["decision"]["target"]
        # Drop spaces and underscores before comparing.
        target = target.replace(" ", "").replace("_", "")

        if action.lower() == "move":
            for move in battle.available_moves:
                if move.id.lower() == target.lower():
                    next_action = self.create_order(move, dynamax=self._should_dynamax(battle))

        elif action.lower() == "switch":
            for pokemon in battle.available_switches:
                if pokemon.species.lower() == target.lower():
                    next_action = self.create_order(pokemon)

        if next_action is None:
            raise ValueError("Value Error")

        return next_action

    def check_status(self, status):
        """Return a word for ``status``.

        A falsy status gives ``""``; a status whose ``value`` is not in
        1..7 gives ``None``.
        """
        if not status:
            return ""
        return self._STATUS_LABELS.get(status.value)

    def boost_multiplier(self, state, level):
        """Return the multiplier for boost ``level`` of ``state``.

        ``"accuracy"`` uses its own table; every other state shares the
        second one. Levels outside -6..6 return ``None``.
        """
        if state == "accuracy":
            return self._ACCURACY_BOOST_MULTIPLIERS.get(level)
        return self._STAT_BOOST_MULTIPLIERS.get(level)

    def battle_summary(self):
        """Summarise every battle in ``self.battles``.

        Returns:
            ``(beat_list, remain_list, win_list, tag_list)``. ``beat_list``
            holds, per battle, the sum of ``1 - current_hp_fraction`` over
            the opponent team. ``remain_list`` holds the sum of
            ``current_hp_fraction`` over the own team. ``tag_list`` holds
            the battle tags. ``win_list`` gets one ``1`` per won battle only,
            so it can be shorter than the other lists.
        """
        beat_list = []
        remain_list = []
        win_list = []
        tag_list = []

        for tag, battle in self.battles.items():
            beat_score = 0
            for mon in battle.opponent_team.values():
                beat_score += (1 - mon.current_hp_fraction)
            beat_list.append(beat_score)

            remain_score = 0
            for mon in battle.team.values():
                remain_score += mon.current_hp_fraction
            remain_list.append(remain_score)

            if battle.won:
                win_list.append(1)

            tag_list.append(tag)

        return beat_list, remain_list, win_list, tag_list

    def reward_computing_helper(
        self,
        battle: AbstractBattle,
        *,
        fainted_value: float = 0.0,
        hp_value: float = 0.0,
        number_of_pokemons: int = 6,
        starting_value: float = 0.0,
        status_value: float = 0.0,
        victory_value: float = 1.0,
    ) -> float:
        """Compute the reward for ``battle`` as a delta since the previous call.

        A battle value is built from both teams' HP fractions, faints and
        statuses plus the win/loss outcome. The value stored for this battle
        in ``self._reward_buffer`` is subtracted from it, and the buffer is
        then updated.

        Args:
            battle: Battle to evaluate.
            fainted_value: Weight of each fainted Pokemon.
            hp_value: Weight of each Pokemon's HP fraction.
            number_of_pokemons: Team size assumed when a team lists fewer members.
            starting_value: Buffer value used the first time a battle is seen.
            status_value: Weight of each non-fainted Pokemon that has a status.
            victory_value: Amount added on a win and subtracted on a loss.

        Returns:
            Current battle value minus the previously stored one.
        """
        if battle not in self._reward_buffer:
            self._reward_buffer[battle] = starting_value
        current_value = 0

        # Own team: HP counts for us, faints and statuses against us.
        for mon in battle.team.values():
            current_value += mon.current_hp_fraction * hp_value
            if mon.fainted:
                current_value -= fainted_value
            elif mon.status is not None:
                current_value -= status_value

        # Team slots not present in ``battle.team`` count as full HP.
        current_value += (number_of_pokemons - len(battle.team)) * hp_value

        # Opponent team: the mirror image of the above.
        for mon in battle.opponent_team.values():
            current_value -= mon.current_hp_fraction * hp_value
            if mon.fainted:
                current_value += fainted_value
            elif mon.status is not None:
                current_value += status_value

        current_value -= (number_of_pokemons - len(battle.opponent_team)) * hp_value

        if battle.won:
            current_value += victory_value
        elif battle.lost:
            current_value -= victory_value

        # The returned reward is the change since the last call.
        to_return = current_value - self._reward_buffer[battle]
        self._reward_buffer[battle] = current_value

        return to_return

    def choose_max_damage_move(self, battle: AbstractBattle):
        """Return an order for the available move with the highest base power.

        Falls back to ``choose_random_move`` when no move is available.
        """
        if battle.available_moves:
            best_move = max(battle.available_moves, key=lambda move: move.base_power)
            return self.create_order(best_move)
        return self.choose_random_move(battle)

.\PokeLLMon\poke_env\player\llama_player.py

代码语言:python
# 导入所需的模块
from poke_env.player.gpt_player import LLMPlayer
from poke_env.environment.abstract_battle import AbstractBattle
import json
from peft import PeftModel
import transformers
import torch
from poke_env.player.player import BattleOrder

# Authentication token handed to `from_pretrained` as `use_auth_token`; empty by default
my_token = ""
# Ignore index (presumably the label value skipped by the training loss; it is
# not referenced in the visible code — TODO confirm)
IGNORE_INDEX = -100
# Default padding token string
DEFAULT_PAD_TOKEN = "[PAD]"
# Default end-of-sequence token string
DEFAULT_EOS_TOKEN = "</s>"
# Default beginning-of-sequence token string
DEFAULT_BOS_TOKEN = "<s>"
# Default unknown-token string
DEFAULT_UNK_TOKEN = "<unk>"

# LLAMAPlayer: an LLMPlayer that loads its language model locally
class LLAMAPlayer(LLMPlayer):
    """LLM player backed by a causal language model loaded with transformers.

    The constructor loads a tokenizer and a model from ``model_name_or_path``,
    optionally wraps the model with LoRA weights through ``PeftModel`` and
    finally switches the model to evaluation mode.
    """

    def __init__(self, battle_format,
                 model_name_or_path: str = "",
                 # tokenizer_path: str = "",
                 lora_weights: str = "",
                 model_max_length: int = 2048,
                 w_reason = False,
                 log_dir = "",
                 account_configuration=None,
                 server_configuration=None,
                 ):
        """Initialise the player and load the model.

        :param battle_format: Battle format forwarded to the parent class.
        :param model_name_or_path: Model identifier or path; must be non-empty.
        :param lora_weights: Optional LoRA weights to load on top of the model.
        :param model_max_length: ``model_max_length`` given to the tokenizer.
        :param w_reason: Stored on the instance as ``self.w_reason``
            (presumably toggles reasoning output — TODO confirm in LLMPlayer).
        :param log_dir: Stored on the instance as ``self.log_dir``.
        :param account_configuration: Forwarded to the parent class.
        :param server_configuration: Forwarded to the parent class.
        :raises AssertionError: If ``model_name_or_path`` is empty.
        """
        super().__init__(
            battle_format=battle_format,
            account_configuration=account_configuration,
            server_configuration=server_configuration,
        )

        # Counters and per-battle bookkeeping.
        self.except_cnt = 0
        self.total_cnt = 0
        self.log_dir = log_dir
        self.w_reason = w_reason
        self.last_output = None
        self.last_state_prompt = None

        # A model location is mandatory.
        assert model_name_or_path, "Please specify the model path"

        # The same token is used for the tokenizer and for the model.
        hub_auth = {"use_auth_token": my_token}

        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_name_or_path,
            model_max_length=model_max_length,
            padding_side="right",
            use_fast=False,
            **hub_auth,
        )

        base_model = transformers.AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            load_in_8bit=False,
            torch_dtype=torch.float16,
            device_map="auto",
            **hub_auth,
        )
        self.model = base_model

        # Optionally layer the LoRA adapter over the base model.
        if lora_weights:
            print("Recover LoRA weights..")
            self.model = PeftModel.from_pretrained(
                self.model,
                lora_weights,
                torch_dtype=torch.float16,
            )

        print("Loading finished...")
        # Inference only: put the model in evaluation mode.
        self.model.eval()

.\PokeLLMon\poke_env\player\openai_api.py

代码语言:javascript复制
"""This module defines a player class with the OpenAI API on the main thread.
For a black-box implementation consider using the module env_player.
"""
# 导入必要的模块
from __future__ import annotations

import asyncio
import copy
import random
import time
from abc import ABC, abstractmethod
from logging import Logger
from typing import Any, Awaitable, Callable, Dict, Generic, List, Optional, Tuple, Union

# Third-party imports (gymnasium)
from gymnasium.core import ActType, Env, ObsType
from gymnasium.spaces import Discrete, Space

# 导入自定义模块
from poke_env.concurrency import POKE_LOOP, create_in_poke_loop
from poke_env.environment.abstract_battle import AbstractBattle
from poke_env.player.battle_order import BattleOrder, ForfeitBattleOrder
from poke_env.player.player import Player
from poke_env.ps_client import AccountConfiguration
from poke_env.ps_client.server_configuration import (
    LocalhostServerConfiguration,
    ServerConfiguration,
)
from poke_env.teambuilder.teambuilder import Teambuilder

# Wrapper exposing an asyncio.Queue through both coroutine and blocking calls
class _AsyncQueue:
    """Pair every queue operation with an ``async_*`` coroutine twin.

    The blocking variants submit the queue coroutine to ``POKE_LOOP`` with
    ``asyncio.run_coroutine_threadsafe`` and wait for the returned future.
    """

    def __init__(self, queue: asyncio.Queue[Any]):
        self.queue = queue

    async def async_get(self):
        """Await and return the next item of the queue."""
        item = await self.queue.get()
        return item

    def get(self):
        """Block until an item is available and return it."""
        return asyncio.run_coroutine_threadsafe(self.queue.get(), POKE_LOOP).result()

    async def async_put(self, item: Any):
        """Await until ``item`` has been stored in the queue."""
        await self.queue.put(item)

    def put(self, item: Any):
        """Block until ``item`` has been stored in the queue."""
        pending = asyncio.run_coroutine_threadsafe(self.queue.put(item), POKE_LOOP)
        pending.result()

    def empty(self):
        """Return whether the wrapped queue is currently empty."""
        return self.queue.empty()

    def join(self):
        """Block until every queued item has been marked as processed."""
        pending = asyncio.run_coroutine_threadsafe(self.queue.join(), POKE_LOOP)
        pending.result()

    async def async_join(self):
        """Await until every queued item has been marked as processed."""
        await self.queue.join()

# Player whose decisions are exchanged with the Gym environment through queues
class _AsyncPlayer(Generic[ObsType, ActType], Player):
    """Player that publishes observations and waits for externally chosen actions.

    Each time a move is requested, the embedded battle is pushed on
    ``observations`` and the coroutine then waits for a value on ``actions``.
    """

    actions: _AsyncQueue
    observations: _AsyncQueue

    def __init__(
        self,
        user_funcs: OpenAIGymEnv[ObsType, ActType],
        username: str,
        **kwargs: Any,
    ):
        """Create the player and its two single-slot queues.

        :param user_funcs: Environment providing ``embed_battle`` and
            ``action_to_move``.
        :param username: Name temporarily assigned to the class while the
            parent initializer runs.
        :param kwargs: Forwarded unchanged to ``Player.__init__``.
        """
        # The class name is swapped only for the duration of the parent
        # initializer (presumably Player derives the account name from it —
        # TODO confirm in Player.__init__), then restored.
        self.__class__.__name__ = username
        super().__init__(**kwargs)
        self.__class__.__name__ = "_AsyncPlayer"
        # Both queues hold at most one element.
        self.observations = _AsyncQueue(create_in_poke_loop(asyncio.Queue, 1))
        self.actions = _AsyncQueue(create_in_poke_loop(asyncio.Queue, 1))
        self.current_battle: Optional[AbstractBattle] = None
        # Required by _env_move and _battle_finished_callback; without this
        # assignment both fail with AttributeError.
        self._user_funcs = user_funcs

    def choose_move(self, battle: AbstractBattle) -> Awaitable[BattleOrder]:
        """Return the awaitable that resolves to the order for ``battle``."""
        return self._env_move(battle)

    async def _env_move(self, battle: AbstractBattle) -> BattleOrder:
        """Publish the battle state, wait for an action and convert it to an order.

        :param battle: The battle for which an order is requested.
        :return: ``ForfeitBattleOrder`` when the received action is ``-1``,
            otherwise the result of ``action_to_move``.
        :raises RuntimeError: If ``battle`` differs from the tracked battle.
        """
        # Adopt the incoming battle when none is tracked or the last one ended.
        if not self.current_battle or self.current_battle.finished:
            self.current_battle = battle
        if not self.current_battle == battle:
            raise RuntimeError("Using different battles for queues")
        battle_to_send = self._user_funcs.embed_battle(battle)
        await self.observations.async_put(battle_to_send)
        action = await self.actions.async_get()
        # -1 is the value used to request a forfeit.
        if action == -1:
            return ForfeitBattleOrder()
        return self._user_funcs.action_to_move(action, battle)

    def _battle_finished_callback(self, battle: AbstractBattle):
        """Schedule the final battle embedding onto the observations queue."""
        to_put = self._user_funcs.embed_battle(battle)
        # Scheduled on POKE_LOOP; the returned future is not awaited here.
        asyncio.run_coroutine_threadsafe(self.observations.async_put(to_put), POKE_LOOP)
# Metaclass derived from the metaclass of ABC
class _ABCMetaclass(type(ABC)):
    """Subclass of ``type(ABC)`` that adds no behaviour of its own."""

# Metaclass derived from the metaclass of Env
class _EnvMetaclass(type(Env)):
    """Subclass of ``type(Env)`` that adds no behaviour of its own."""

# Metaclass inheriting from both _EnvMetaclass and _ABCMetaclass
class _OpenAIGymEnvMetaclass(_EnvMetaclass, _ABCMetaclass):
    """Combined metaclass usable by a class deriving from both Env and ABC."""

# OpenAIGymEnv derives from Env[ObsType, ActType] and ABC and uses
# _OpenAIGymEnvMetaclass, which inherits from the metaclasses of both bases.
class OpenAIGymEnv(
    Env[ObsType, ActType],
    ABC,
    metaclass=_OpenAIGymEnvMetaclass,
):
    """
    Base class implementing the OpenAI Gym API on the main thread.
    """

    # Maximum number of polls in reset() while waiting for the agent's battle
    _INIT_RETRIES = 100
    # Value passed to time.sleep() between two of those polls
    _TIME_BETWEEN_RETRIES = 0.5
    # Maximum number of polls while waiting for a previous challenge task to end
    _SWITCH_CHALLENGE_TASK_RETRIES = 30
    # Value passed to time.sleep() between two of those polls. The identifier
    # is spelled "RETIRES"; it is referenced under that name elsewhere in the class.
    _TIME_BETWEEN_SWITCH_RETIRES = 1
    # Initializer
    def __init__(
        self,
        account_configuration: Optional[AccountConfiguration] = None,
        *,
        avatar: Optional[int] = None,
        battle_format: str = "gen8randombattle",
        log_level: Optional[int] = None,
        save_replays: Union[bool, str] = False,
        server_configuration: Optional[
            ServerConfiguration
        ] = LocalhostServerConfiguration,
        start_timer_on_battle_start: bool = False,
        start_listening: bool = True,
        ping_interval: Optional[float] = 20.0,
        ping_timeout: Optional[float] = 20.0,
        team: Optional[Union[str, Teambuilder]] = None,
        start_challenging: bool = False,
    ):
        """Create the wrapped agent and the state used by the Gym API methods.

        NOTE(review): the body of this method was truncated in this copy of
        the file. It is reconstructed from the attributes the other methods
        of this class read (``agent``, ``_actions``, ``_observations``,
        ``current_battle``, ``last_battle``, ``_keep_challenging``,
        ``_challenge_task``, ``_seed_initialized``) and from upstream
        poke_env — confirm against the original repository.

        :param account_configuration: Forwarded to the wrapped player.
        :param avatar: Forwarded to the wrapped player.
        :param battle_format: Forwarded to the wrapped player.
        :param log_level: Forwarded to the wrapped player.
        :param save_replays: Forwarded to the wrapped player.
        :param server_configuration: Forwarded to the wrapped player.
        :param start_timer_on_battle_start: Forwarded to the wrapped player.
        :param start_listening: Forwarded to the wrapped player.
        :param ping_interval: Forwarded to the wrapped player.
        :param ping_timeout: Forwarded to the wrapped player.
        :param team: Forwarded to the wrapped player.
        :param start_challenging: When true, the challenge loop is started
            at the end of initialization.
        """
        self.agent = _AsyncPlayer(
            self,
            username=self.__class__.__name__,  # type: ignore
            account_configuration=account_configuration,
            avatar=avatar,
            battle_format=battle_format,
            log_level=log_level,
            max_concurrent_battles=1,
            save_replays=save_replays,
            server_configuration=server_configuration,
            start_timer_on_battle_start=start_timer_on_battle_start,
            start_listening=start_listening,
            ping_interval=ping_interval,
            ping_timeout=ping_timeout,
            team=team,
        )
        # Shortcuts to the agent's queues, used by reset() and step().
        self._actions = self.agent.actions
        self._observations = self.agent.observations
        # Gym spaces are derived from the user-provided abstract methods.
        self.action_space = Discrete(self.action_space_size())
        self.observation_space = self.describe_embedding()
        self.current_battle: Optional[AbstractBattle] = None
        self.last_battle: Optional[AbstractBattle] = None
        self._keep_challenging: bool = False
        self._challenge_task = None
        self._seed_initialized: bool = False
        if start_challenging:
            self._keep_challenging = True
            self.start_challenging()
    # Abstract hook: reward computation
    @abstractmethod
    def calc_reward(
        self, last_battle: AbstractBattle, current_battle: AbstractBattle
    ) -> float:
        """
        Computes the reward of the current battle state; the state of the
        previous turn is supplied as well so both can be compared.

        :param last_battle: Battle state of the previous turn.
        :type last_battle: AbstractBattle
        :param current_battle: Battle state of the current turn.
        :type current_battle: AbstractBattle

        :return: Reward associated with current_battle.
        :rtype: float
        """

    # Abstract hook: translate an action into a battle order
    @abstractmethod
    def action_to_move(self, action: int, battle: AbstractBattle) -> BattleOrder:
        """
        Converts the given action into the matching BattleOrder.

        :param action: Action to carry out.
        :type action: int
        :param battle: Current battle state.
        :type battle: AbstractBattle

        :return: Battle order corresponding to the action in the context of
            the current battle.
        :rtype: BattleOrder
        """

    # Abstract hook: embed a battle state as an observation
    @abstractmethod
    def embed_battle(self, battle: AbstractBattle) -> ObsType:
        """
        Builds the embedding of the current battle state, in a format
        compatible with the OpenAI gym API.

        :param battle: Current battle state.
        :type battle: AbstractBattle

        :return: Embedding of the current battle state.
        """

    # Abstract hook: describe the observation space
    @abstractmethod
    def describe_embedding(self) -> Space[ObsType]:
        """
        Describes the embedding. The returned Space must specify low bounds
        and high bounds.

        :return: Description of the embedding.
        :rtype: Space
        """

    # Abstract hook: size of the action space
    @abstractmethod
    def action_space_size(self) -> int:
        """
        Gives the size of the action space. For a size x, valid actions range
        from 0 to x - 1.

        :return: Size of the action space.
        :rtype: int
        """

    # Abstract hook: opponent (or candidate opponents) for the next challenge
    @abstractmethod
    def get_opponent(
        self,
    ) -> Union[Player, str, List[Player], List[str]]:
        """
        Provides the opponent, or a list of opponents, to challenge on the
        next iteration of the challenge loop. When a list is returned, one of
        its elements is picked at random during the challenge loop.

        :return: The opponent (or list of opponents).
        :rtype: Player or str or list(Player) or list(str)
        """
    # Resolve get_opponent() to a single opponent
    def _get_opponent(self) -> Union[Player, str]:
        """Return one opponent, drawing at random when a list is provided."""
        candidates = self.get_opponent()
        if isinstance(candidates, list):
            return random.choice(candidates)
        return candidates

    # Reset the environment and return the first observation of the next battle
    def reset(
        self,
        *,
        seed: Optional[int] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> Tuple[ObsType, Dict[str, Any]]:
        """
        Moves the environment on to the agent's next battle.

        :param seed: Seed passed to the parent ``reset``; when omitted, the
            current time is used once, on the first call.
        :type seed: int, optional
        :param options: Accepted for API compatibility; not read by this method.
        :type options: Dict[str, Any], optional

        :return: The first observation read from the observations queue and
            the dictionary returned by ``get_additional_info``.
        :rtype: Tuple[ObsType, Dict[str, Any]]

        :raises RuntimeError: If the agent has no battle after all retries, or
            if the environment's unfinished battle is not the agent's battle.
        """
        # Seed the parent class: with the explicit seed when one is given...
        if seed is not None:
            super().reset(seed=seed)  # type: ignore
            self._seed_initialized = True
        # ...otherwise with the current time, but only once.
        elif not self._seed_initialized:
            super().reset(seed=int(time.time()))  # type: ignore
            self._seed_initialized = True
        # Poll until the agent has a battle, giving up after _INIT_RETRIES polls.
        if not self.agent.current_battle:
            count = self._INIT_RETRIES
            while not self.agent.current_battle:
                if count == 0:
                    raise RuntimeError("Agent is not challenging")
                count -= 1
                time.sleep(self._TIME_BETWEEN_RETRIES)
        # A battle is still running: send -1 (handled by the agent as a
        # forfeit) and consume the observation that follows.
        if self.current_battle and not self.current_battle.finished:
            if self.current_battle == self.agent.current_battle:
                self._actions.put(-1)
                self._observations.get()
            else:
                raise RuntimeError(
                    "Environment and agent aren't synchronized. Try to restart"
                )
        # Wait until the agent's battle differs from the environment's one.
        while self.current_battle == self.agent.current_battle:
            time.sleep(0.01)
        # Track the agent's battle and store a deep copy as the previous state.
        self.current_battle = self.agent.current_battle
        battle = copy.copy(self.current_battle)
        # The logger is cleared on the shallow copy before deep-copying
        # (presumably because loggers should not be deep-copied — TODO confirm).
        battle.logger = None
        self.last_battle = copy.deepcopy(battle)
        return self._observations.get(), self.get_additional_info()

    # Extra info dictionary returned by reset() and step()
    def get_additional_info(self) -> Dict[str, Any]:
        """
        Provides the additional info dictionary; empty by default.
        Override only if you really need it.

        :return: Additional information as a Dict
        :rtype: Dict
        """
        info: Dict[str, Any] = {}
        return info
    def step(
        self, action: ActType
    ) -> Tuple[ObsType, float, bool, bool, Dict[str, Any]]:
        """
        Send one action to the agent and report the resulting transition.

        :param ActType action: The action to be executed.
        :return: New observation, reward, termination flag, truncation flag
            and info dictionary.
        :rtype: Tuple[ObsType, float, bool, bool, Dict[str, Any]]
        :raises RuntimeError: If the current battle is already finished.
        """
        # Without a battle, a step behaves like a reset with a zero reward.
        if not self.current_battle:
            first_obs, first_info = self.reset()
            return first_obs, 0.0, False, False, first_info
        if self.current_battle.finished:
            raise RuntimeError("Battle is already finished, call reset")

        # Snapshot the state before acting; the logger is cleared on the
        # shallow copy before the deep copy is taken.
        snapshot = copy.copy(self.current_battle)
        snapshot.logger = None
        self.last_battle = copy.deepcopy(snapshot)

        # Hand the action to the agent and wait for the next observation.
        self._actions.put(action)
        observation = self._observations.get()
        reward = self.calc_reward(self.last_battle, self.current_battle)

        terminated = False
        truncated = False
        if self.current_battle.finished:
            team_size = self.current_battle.team_size
            own_fainted = sum(
                1 for mon in self.current_battle.team.values() if mon.fainted
            )
            rival_fainted = sum(
                1 for mon in self.current_battle.opponent_team.values() if mon.fainted
            )
            own_wiped = team_size - own_fainted == 0
            rival_wiped = team_size - rival_fainted == 0
            # Exactly one wiped side is a regular ending; anything else is
            # reported as a truncation.
            if own_wiped != rival_wiped:
                terminated = True
            else:
                truncated = True
        return observation, reward, terminated, truncated, self.get_additional_info()
    # Print a one-line summary of the current battle
    def render(self, mode: str = "human"):
        """
        Prints the turn number, both teams' fainted markers, and the HP and
        species of the two active pokemons. Nothing is printed when there is
        no current battle.

        :param mode: Accepted for API compatibility; not read by this method.
        :type mode: str
        """
        if self.current_battle is not None:
            print(
                # Eight conversions, one per argument below. The garbled
                # template had only two, which raised TypeError.
                "  Turn %4d. | [%s][%3d/%3dhp] %10.10s - %10.10s [%3d%%hp][%s]"
                % (
                    self.current_battle.turn,
                    "".join(
                        [
                            "⦻" if mon.fainted else "●"
                            for mon in self.current_battle.team.values()
                        ]
                    ),
                    self.current_battle.active_pokemon.current_hp or 0,
                    self.current_battle.active_pokemon.max_hp or 0,
                    self.current_battle.active_pokemon.species,
                    self.current_battle.opponent_active_pokemon.species,
                    self.current_battle.opponent_active_pokemon.current_hp or 0,
                    "".join(
                        [
                            "⦻" if mon.fainted else "●"
                            for mon in self.current_battle.opponent_team.values()
                        ]
                    ),
                ),
                # Carriage return keeps rewriting the same line while the
                # battle runs; a newline closes it once the battle is over.
                end="\n" if self.current_battle.finished else "\r",
            )

    # Stop the challenge loop and wait for it to finish
    def close(self, purge: bool = True):
        """
        Stops the challenge loop, blocking until ``_stop_challenge_loop`` ends.

        :param purge: Forwarded to ``_stop_challenge_loop``.
        :type purge: bool
        """
        no_running_battle = self.current_battle is None or self.current_battle.finished
        if no_running_battle:
            # Wait one second, then adopt the agent's battle if it has changed.
            time.sleep(1)
            if self.current_battle != self.agent.current_battle:
                self.current_battle = self.agent.current_battle
        asyncio.run_coroutine_threadsafe(
            self._stop_challenge_loop(purge=purge), POKE_LOOP
        ).result()
    def background_send_challenge(self, username: str):
        """
        Schedules a single challenge to the given player and returns
        immediately, so the OpenAI gym API can be used meanwhile.

        :param username: The username of the player to challenge.
        :type username: str
        :raises RuntimeError: If a challenge task is already running.
        """
        already_running = self._challenge_task and not self._challenge_task.done()
        if already_running:
            raise RuntimeError(
                "Agent is already challenging opponents with the challenging loop. "
                "Try to specify 'start_challenging=True' during instantiation or call "
                "'await agent.stop_challenge_loop()' to clear the task."
            )
        # Scheduled on POKE_LOOP; the future is kept as the current task.
        challenge = self.agent.send_challenges(username, 1)
        self._challenge_task = asyncio.run_coroutine_threadsafe(challenge, POKE_LOOP)

    def background_accept_challenge(self, username: str):
        """
        Schedules the acceptance of a single challenge from the given player
        and returns immediately, so the OpenAI gym API can be used meanwhile.

        :param username: The username of the player whose challenge is accepted.
        :type username: str
        :raises RuntimeError: If a challenge task is already running.
        """
        already_running = self._challenge_task and not self._challenge_task.done()
        if already_running:
            raise RuntimeError(
                "Agent is already challenging opponents with the challenging loop. "
                "Try to specify 'start_challenging=True' during instantiation or call "
                "'await agent.stop_challenge_loop()' to clear the task."
            )
        # Scheduled on POKE_LOOP; the future is kept as the current task.
        acceptance = self.agent.accept_challenges(username, 1, self.agent.next_team)
        self._challenge_task = asyncio.run_coroutine_threadsafe(acceptance, POKE_LOOP)

    async def _challenge_loop(
        self,
        n_challenges: Optional[int] = None,
        callback: Optional[Callable[[AbstractBattle], None]] = None,
    ):
        """
        Plays battles against the opponents returned by ``_get_opponent``.

        :param n_challenges: Number of battles to play. When falsy (``None``
            or 0) battles are played for as long as ``_keep_challenging`` is true.
        :type n_challenges: int, optional
        :param callback: Called after each battle with a deep copy of the
            current battle, when there is one.
        :type callback: Callable[[AbstractBattle], None], optional
        :raises ValueError: If ``n_challenges`` is negative.
        """

        async def play_single_battle() -> None:
            # Player instances are battled directly; anything else is challenged.
            rival = self._get_opponent()
            if isinstance(rival, Player):
                await self.agent.battle_against(rival, 1)
            else:
                await self.agent.send_challenges(rival, 1)
            if callback and self.current_battle is not None:
                callback(copy.deepcopy(self.current_battle))

        if not n_challenges:
            while self._keep_challenging:
                await play_single_battle()
        elif n_challenges > 0:
            for _ in range(n_challenges):
                await play_single_battle()
        else:
            raise ValueError(f"Number of challenges must be > 0. Got {n_challenges}")

    # Launch the challenge loop on POKE_LOOP
    def start_challenging(
        self,
        n_challenges: Optional[int] = None,
        callback: Optional[Callable[[AbstractBattle], None]] = None,
    ):
        """
        Starts the challenge loop.

        :param n_challenges: The number of challenges to send. If empty it will run until
            stopped.
        :type n_challenges: int, optional
        :param callback: The function to callback after each challenge with a copy of
            the final battle state.
        :type callback: Callable[[AbstractBattle], None], optional
        :raises RuntimeError: If a previous challenge task is still running
            after all retries.
        """
        # Give a running task the chance to finish before replacing it.
        if self._challenge_task and not self._challenge_task.done():
            for _ in range(self._SWITCH_CHALLENGE_TASK_RETRIES):
                if self._challenge_task.done():
                    break
                time.sleep(self._TIME_BETWEEN_SWITCH_RETIRES)
            else:
                if not self._challenge_task.done():
                    raise RuntimeError("Agent is already challenging")
        # Without a count, the loop runs until _keep_challenging is cleared.
        if not n_challenges:
            self._keep_challenging = True
        loop_coroutine = self._challenge_loop(n_challenges, callback)
        self._challenge_task = asyncio.run_coroutine_threadsafe(loop_coroutine, POKE_LOOP)

    async def _ladder_loop(
        self,
        n_challenges: Optional[int] = None,
        callback: Optional[Callable[[AbstractBattle], None]] = None,
    ):
        """
        Plays ladder games through the agent.

        :param n_challenges: Number of ladder games to play. When falsy
            (``None`` or 0) games are played for as long as
            ``_keep_challenging`` is true.
        :type n_challenges: int, optional
        :param callback: Called after each game with a deep copy of the
            current battle, when there is one.
        :type callback: Callable[[AbstractBattle], None], optional
        :raises ValueError: If ``n_challenges`` is negative.
        """

        async def play_single_game() -> None:
            await self.agent.ladder(1)
            if callback and self.current_battle is not None:
                callback(copy.deepcopy(self.current_battle))

        if not n_challenges:
            while self._keep_challenging:
                await play_single_game()
            return
        if n_challenges <= 0:
            raise ValueError(
                f"Number of challenges must be > 0. Got {n_challenges}"
            )
        for _ in range(n_challenges):
            await play_single_game()

    # Launch the ladder loop on POKE_LOOP
    def start_laddering(
        self,
        n_challenges: Optional[int] = None,
        callback: Optional[Callable[[AbstractBattle], None]] = None,
    ):
        """
        Starts the laddering loop.

        :param n_challenges: The number of ladder games to play. If empty it
            will run until stopped.
        :type n_challenges: int, optional
        :param callback: The function to callback after each challenge with a
            copy of the final battle state.
        :type callback: Callable[[AbstractBattle], None], optional
        :raises RuntimeError: If a previous challenge task is still running
            after all retries.
        """
        # Give a running task the chance to finish before replacing it.
        if self._challenge_task and not self._challenge_task.done():
            for _ in range(self._SWITCH_CHALLENGE_TASK_RETRIES):
                if self._challenge_task.done():
                    break
                time.sleep(self._TIME_BETWEEN_SWITCH_RETIRES)
            else:
                if not self._challenge_task.done():
                    raise RuntimeError("Agent is already challenging")
        # Without a count, the loop runs until _keep_challenging is cleared.
        if not n_challenges:
            self._keep_challenging = True
        loop_coroutine = self._ladder_loop(n_challenges, callback)
        self._challenge_task = asyncio.run_coroutine_threadsafe(loop_coroutine, POKE_LOOP)

    async def _stop_challenge_loop(
        self, force: bool = True, wait: bool = True, purge: bool = False
    ):
        """
        Stops the challenge loop and clears the environment's battle state.

        :param force: When true and a battle is still running, -1 is put on
            the actions queue (the agent treats -1 as a forfeit).
        :type force: bool
        :param wait: When true, waits for the challenge task to complete.
        :type wait: bool
        :param purge: When true, calls ``reset_battles`` on the agent at the end.
        :type purge: bool
        :raises RuntimeError: If an action is still queued two seconds after
            the stop was requested.
        """
        # Makes the "run until stopped" loops exit after their current battle.
        self._keep_challenging = False

        if force:
            if self.current_battle and not self.current_battle.finished:
                # A queued action gets two seconds to be consumed by the agent.
                if not self._actions.empty():
                    await asyncio.sleep(2)
                    if not self._actions.empty():
                        raise RuntimeError(
                            "The agent is still sending actions. "
                            "Use this method only when training or "
                            "evaluation are over."
                        )
                # Drop a pending observation, then request the forfeit.
                if not self._observations.empty():
                    await self._observations.async_get()
                await self._actions.async_put(-1)

        if wait and self._challenge_task:
            # Poll once per second until the task is done; result() is then
            # called on the finished task.
            while not self._challenge_task.done():
                await asyncio.sleep(1)
            self._challenge_task.result()

        # Clear the task and battle references, then drain both queues.
        self._challenge_task = None
        self.current_battle = None
        self.agent.current_battle = None
        while not self._actions.empty():
            await self._actions.async_get()
        while not self._observations.empty():
            await self._observations.async_get()

        if purge:
            self.agent.reset_battles()

    def reset_battles(self):
        """Resets the player's inner battle tracker."""
        # Delegated entirely to the wrapped agent.
        wrapped_agent = self.agent
        wrapped_agent.reset_battles()
    # Report whether the challenge task is done, optionally waiting for it
    def done(self, timeout: Optional[int] = None) -> bool:
        """
        Tells whether the challenge task is done, possibly after waiting.

        :param timeout: The amount of time to wait for if the task is not already done.
            If empty it will wait until the task is done.
        :type timeout: int, optional

        :return: True if there is no task, if the task is done, or if it gets
            completed within the timeout; False otherwise.
        :rtype: bool
        """
        # No task at all counts as done.
        if self._challenge_task is None:
            return True
        if timeout is not None:
            if not self._challenge_task.done():
                # Sleep for the whole timeout, then check a single time.
                time.sleep(timeout)
                return self._challenge_task.done()
            return True
        # No timeout given: block until the task completes.
        self._challenge_task.result()
        return True

    # Properties forwarding attributes of the wrapped Player (self.agent)

    @property
    def battles(self) -> Dict[str, AbstractBattle]:
        """Forwarded from ``self.agent.battles``."""
        return self.agent.battles

    @property
    def format(self) -> str:
        """Forwarded from ``self.agent.format``."""
        return self.agent.format

    @property
    def format_is_doubles(self) -> bool:
        """Forwarded from ``self.agent.format_is_doubles``."""
        return self.agent.format_is_doubles

    @property
    def n_finished_battles(self) -> int:
        """Forwarded from ``self.agent.n_finished_battles``."""
        return self.agent.n_finished_battles

    @property
    def n_lost_battles(self) -> int:
        """Forwarded from ``self.agent.n_lost_battles``."""
        return self.agent.n_lost_battles

    @property
    def n_tied_battles(self) -> int:
        """Forwarded from ``self.agent.n_tied_battles``."""
        return self.agent.n_tied_battles

    @property
    def n_won_battles(self) -> int:
        """Forwarded from ``self.agent.n_won_battles``."""
        return self.agent.n_won_battles

    @property
    def win_rate(self) -> float:
        """Forwarded from ``self.agent.win_rate``."""
        return self.agent.win_rate

    # Properties forwarding attributes of the player's network interface

    @property
    def logged_in(self) -> asyncio.Event:
        """Event object associated with user login.

        :return: The logged-in event
        :rtype: Event
        """
        return self.agent.ps_client.logged_in

    # Logger associated with the player
    @property
    def logger(self) -> Logger:
        """Logger associated with the player.

        :return: The logger.
        :rtype: Logger
        """
        return self.agent.logger

    # The player's username
    @property
    def username(self) -> str:
        """The player's username.

        :return: The player's username.
        :rtype: str
        """
        return self.agent.username

    # The websocket URL
    @property
    def websocket_url(self) -> str:
        """The websocket url.

        It is derived from the server url.

        :return: The websocket url.
        :rtype: str
        """
        return self.agent.ps_client.websocket_url

    # 获取属性的值
    def __getattr__(self, item: str):
        return getattr(self.agent, item)

0 人点赞