Shortcuts

模型可解释性示例

这是一个TorchX应用示例,它使用captum分析模型的输入,以实现模型可解释性。它使用训练器应用示例训练出的模型,以及数据预处理应用示例生成的预处理数据。输出是一系列叠加了集成梯度(Integrated Gradients)归因结果的图像。

有关使用captum的更多信息,请参见https://captum.ai/tutorials/CIFAR_TorchVision_Interpret

用法

在本地将此主模块作为Python进程运行。下面的运行假设模型已经按照torchx/examples/apps/lightning/train.py中的使用说明进行了训练。

$ torchx run -s local_cwd utils.python \
    --script ./lightning/interpret.py \
    -- \
    --load_path /tmp/torchx/train/last.ckpt \
    --output_path /tmp/torchx/interpret

使用图像查看器来可视化在output_path下生成的*.png文件。

注意

在本地通过TorchX的utils.python内置组件运行,实际上等同于直接运行主模块(例如python ./interpret.py)。使用TorchX启动简单的单进程Python程序的好处在于:只需将-s local_cwd替换为远程调度器(例如-s kubernetes),即可在远程调度器上启动同一程序。

import argparse
import itertools
import os.path
import sys
import tempfile
from typing import List

import fsspec
import torch
from torchx.examples.apps.lightning.data import (
    create_random_data,
    download_data,
    TinyImageNetDataModule,
)
from torchx.examples.apps.lightning.model import TinyImageNetModel


# Make modules in the current working directory importable.
# NOTE(review): this runs after the torchx.examples imports above, so it cannot
# influence how those were resolved -- confirm what still depends on it.
sys.path.append(".")


# FIXME: captum must be imported after torch otherwise it causes python to crash
# NOTE(review): the always-true `if` presumably keeps import sorters/formatters
# from hoisting these imports above `import torch` -- confirm before removing.
if True:
    import numpy as np
    from captum.attr import IntegratedGradients, visualization as viz


def parse_args(argv: List[str]) -> argparse.Namespace:
    """
    Parse the command line flags of the captum example app.

    Args:
        argv: command line arguments, excluding the program name.

    Returns:
        Namespace with ``load_path`` and ``output_path`` (both required) and
        ``data_path`` (``None`` when the flag is not given).
    """
    # (flag, extra add_argument keywords); every flag is a plain string option.
    flag_specs = (
        (
            "--load_path",
            {
                "help": "checkpoint path to load model weights from",
                "required": True,
            },
        ),
        (
            "--data_path",
            {
                "help": "path to load the training data from, if not provided, random dataset will be created",
            },
        ),
        (
            "--output_path",
            {
                "help": "path to place analysis results",
                "required": True,
            },
        ),
    )

    cli = argparse.ArgumentParser(description="example TorchX captum app")
    for flag, extra in flag_specs:
        cli.add_argument(flag, type=str, **extra)

    return cli.parse_args(argv)


def convert_to_rgb(arr: torch.Tensor) -> np.ndarray:  # pyre-ignore[24]
    """
    Convert a single-image tensor of size (1, 3, 64, 64), laid out as
    (N, C, H, W), into a numpy array of size (64, 64, 3), laid out as
    (H, W, C).

    Args:
        arr: tensor holding exactly one 3-channel 64x64 image.

    Returns:
        The image as a channels-last numpy array.

    Raises:
        AssertionError: if the converted array is not of size (64, 64, 3).
    """
    # (C, H, W) -> (H, W, C). permute keeps rows and columns where they are;
    # swapping axes 0 and 2 instead would give (W, H, C), a transposed image.
    out = arr.squeeze().permute(1, 2, 0)
    assert out.shape == (64, 64, 3), "invalid shape produced"
    # detach/cpu so tensors that track gradients or live on another device
    # can still be converted to numpy.
    return out.detach().cpu().numpy()


def main(argv: List[str]) -> None:
    """
    Run integrated-gradients analysis on the first five test images.

    Restores a ``TinyImageNetModel`` from ``--load_path``, builds the test
    dataloader (random data created in a temporary directory when
    ``--data_path`` is not given, otherwise the path returned by
    ``download_data``), and writes one ``ig_<i>.png`` heat map per analyzed
    image under ``--output_path``.

    Args:
        argv: command line arguments, excluding the program name.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        args = parse_args(argv)

        print(f"loading checkpoint: {args.load_path}...")
        # load_from_checkpoint is a classmethod that RETURNS the restored
        # model. Calling it on an existing instance and discarding the result
        # leaves that instance with its freshly initialized weights, so the
        # returned model is the one that must be used.
        model = TinyImageNetModel.load_from_checkpoint(checkpoint_path=args.load_path)

        # Download and setup the data module
        if not args.data_path:
            data_path = os.path.join(tmpdir, "data")
            os.makedirs(data_path)
            create_random_data(data_path)
        else:
            data_path = download_data(args.data_path, tmpdir)
        data = TinyImageNetDataModule(
            data_dir=data_path,
            batch_size=1,
        )

        ig = IntegratedGradients(model)

        data.setup("test")
        dataloader = data.test_dataloader()

        # process first 5 images
        for i, (image, label) in enumerate(itertools.islice(dataloader, 5)):
            print(f"analyzing example {i}")
            model.zero_grad()
            # An all-zero tensor of the same shape as the image is the baseline.
            # The convergence delta is requested but not used further.
            attr_ig, _delta = ig.attribute(
                image,
                target=label,
                baselines=image * 0,
                return_convergence_delta=True,
            )

            if attr_ig.count_nonzero() == 0:
                # Our toy model sometimes has no IG results.
                print("skipping due to zero gradients")
                continue

            fig, _axis = viz.visualize_image_attr(
                convert_to_rgb(attr_ig),
                convert_to_rgb(image),
                method="blended_heat_map",
                sign="all",
                show_colorbar=True,
                title="Overlayed Integrated Gradients",
            )
            out_path = os.path.join(args.output_path, f"ig_{i}.png")
            print(f"saving heatmap to {out_path}")
            # fsspec lets out_path be any filesystem URL fsspec supports.
            with fsspec.open(out_path, "wb") as f:
                fig.savefig(f)


# Run only when executed as a script, and not when a global named NOTEBOOK is
# defined (presumably set by a notebook/docs build that executes this file --
# confirm).
if __name__ == "__main__" and "NOTEBOOK" not in globals():
    main(sys.argv[1:])


# sphinx_gallery_thumbnail_path = '_static/img/gallery-app.png'

脚本总运行时间: ( 0 分钟 0.000 秒)

Gallery generated by Sphinx-Gallery