ray.rllib.evaluation.rollout_worker.RolloutWorker.get_weights#
- RolloutWorker.get_weights(policies: Collection[str] | None = None, inference_only: bool = False) Dict[str, dict][源代码]#
返回此工作者的每个策略的模型权重。
- 参数:
policies – 获取权重的 PolicyIDs 列表。使用 None 表示所有策略。
inference_only – 此参数仅为了与新的API堆栈保持接口一致性而添加。
- 返回:
将策略ID映射到模型权重的字典。
from ray.rllib.evaluation.rollout_worker import RolloutWorker # Create a RolloutWorker. worker = ... weights = worker.get_weights() print(weights)
{"default_policy": {"layer1": array(...), "layer2": ...}}