ray.rllib.evaluation.rollout_worker.RolloutWorker.get_weights#

RolloutWorker.get_weights(policies: Collection[str] | None = None, inference_only: bool = False) Dict[str, dict][源代码]#

返回此工作者的每个策略的模型权重。

参数:
  • policies – 获取权重的 PolicyIDs 列表。使用 None 表示所有策略。

  • inference_only – 此参数仅为了与新的API堆栈保持接口一致性而添加。

返回:

将策略ID映射到模型权重的字典。

from ray.rllib.evaluation.rollout_worker import RolloutWorker
# Create a RolloutWorker.
worker = ...
weights = worker.get_weights()
print(weights)
{"default_policy": {"layer1": array(...), "layer2": ...}}