Shortcuts

追踪

概述与用法

注意

实验性功能,使用风险自负,API可能会发生变化

在TorchX应用程序中,应用程序是二进制文件(可执行文件),因此没有内置的方法可以从应用程序“返回”结果。torchx.runtime.tracking模块允许应用程序返回简单的结果(注意关键词“简单”)。跟踪器模块支持的返回类型是故意受限的。例如,不允许尝试返回训练好的模型权重,这些权重可能高达数百GB。该模块既不是设计用于也不是优化用于传递大量数据或二进制大对象的。

当应用程序作为更高级别的协调工作(例如工作流、管道、超参数优化)的一部分启动时,通常需要使应用程序的结果对协调器或工作流中的其他应用程序可访问。

假设App1和App2依次启动,App1的输出作为App2的输入。由于这些是二进制文件,应用程序之间链接输入/输出的典型方法是将App1的输出文件路径作为App2的输入文件路径传递:

$ app1 --output-file s3://foo/out/app1.out
$ app2 --input-file s3://foo/out/app1.out

尽管这看起来很简单,但有几件事需要注意:

  1. 文件app1.out的格式(app1需要以app2理解的格式编写)

  2. 实际解析URL并写入/读取输出文件

因此,应用程序的主函数最终看起来像这样(用于演示目的的伪代码):

# in app1.py
if __name__ == "__main__":
   accuracy = do_something()
   s3client = ...
   out = {"accuracy": accuracy}

   with open("/tmp/out", "w") as f:
       f = json.dumps(out).encode("utf-8")

   s3client.put(args.output_file, f)

# in app2.py
if __name__ == "__main__":
   s3client = ...
   with open("/tmp/out", "w") as f:
       s3client.get(args.input_file, f)

   with open("/tmp/out", "r") as f:
       in = json.loads(f.read().decode("utf-8"))

   do_something_else(in["accuracy"])

相反,使用跟踪器时,可以在应用程序之间使用具有相同tracker_base的跟踪器,使一个应用程序的返回值可用于另一个应用程序,而无需将一个应用程序的输出文件路径与另一个应用程序的输入文件路径链接起来,并处理自定义序列化和文件写入。

# in app1.py
if __name__ == "__main__":
   accuracy = do_something()
   tracker = FsspecResultTracker(args.tracker_base)
   tracker["app1_out"] = {"accuracy": accuracy}

# in app2.py
if __name__ == "__main__":
   tracker = FsspecResultTracker(args.tracker_base)
   app1_accuracy = tracker["app1_out"]
   do_something_else(app1_accuracy)

结果跟踪器

基础

class torchx.runtime.tracking.ResultTracker[source]

基础结果跟踪器,应该被子类化以实现跟踪器。 通常每个支持存储都有一个跟踪器实现。

用法:

# get and put APIs can be used directly or in map-like API
# the following are equivalent
tracker.put("foo", l2norm=1.2)
tracker["foo"] = {"l2norm": 1.2}

# so are these
tracker.get("foo")["l2norm"] == 1.2
tracker["foo"]["l2norm"] == 1.2

有效的 result 类型有:

  1. 数值类型:整数,浮点数

  2. 字面量:str (utf-8编码时1kb大小限制)

有效的key类型有:

  1. int

  2. str

按照惯例,“斜杠”可以用于键中,以存储统计结果。例如,存储l2norm的均值和标准误差:

tracker[key] = {"l2norm/mean" : 1.2, "l2norm/sem": 3.4}
tracker[key]["l2norm/mean"] # returns 1.2
tracker[key]["l2norm/sem"] # returns 3.4

假设键在跟踪器的后备存储范围内是唯一的。例如,如果跟踪器由本地目录支持,并且key是保存结果的目录中的文件,那么

# same key, different backing directory -> results are not overwritten
FsspecResultTracker("/tmp/foo")["1"] = {"l2norm":1.2}
FsspecResultTracker("/tmp/bar")["1"] = {"l2norm":3.4}

跟踪器不是一个中央实体,因此在同一键上的putget操作之间没有强一致性保证(除了底层存储提供的保证)。同样,在同一键上的两个连续的putget操作之间也没有强一致性保证。

例如:

tracker[1] = {"l2norm":1.2}
tracker[1] = {"l2norm":3.4}
tracker[1] # NOT GUARANTEED TO BE 3.4!

sleep(1*MIN)
tracker[1] # more likely to be 3.4 but still not guaranteed!

强烈建议使用唯一ID作为键。对于简单的任务,这个ID通常是任务ID,或者对于像超参数优化这样的迭代应用,可以是(实验ID,试验编号)或(任务ID,副本/工作器等级)的组合。

Fsspec

class torchx.runtime.tracking.FsspecResultTracker(tracker_base: str)[source]

使用fsspec在后台保存结果的跟踪器。

用法:

from torchx.runtime.tracking import FsspecResultTracker

# PUT: in trainer.py
tracker_base = "/tmp/foobar" # also supports URIs (e.g. "s3://bucket/trainer/123")
tracker = FsspecResultTracker(tracker_base)
tracker["attempt_1/out"] = {"accuracy": 0.233}

# GET: anywhere outside trainer.py
tracker = FsspecResultTracker(tracker_base)
print(tracker["attempt_1/out"]["accuracy"])
0.233