集成指南

Aim与您喜爱的ML框架无缝集成 - Pytorch Ignite、Pytorch Lightning、Hugging Face等。 基础集成指南可在快速入门部分找到。

在本节中,我们将深入探讨如何扩展基础日志记录器,通过调整它们来追踪更多内容。基础日志记录器仅能追踪特定指标和超参数。

Aim回调/适配器/记录器有两种扩展方式:

  • 通过继承并重写负责日志记录的主要方法。

  • 通过使用名为experiment的公共属性,可以访问底层的aim.Run对象,从而轻松跟踪对项目有益的新指标、参数和其他元数据。

PyTorch Ignite

两种回调扩展机制均可与Pytorch Ignite配合使用。 在以下示例中,您将了解如何利用实验属性,在训练完成后通过aim.Image来跟踪混淆矩阵图像。

这里有一个示例colab笔记本

from aim import Image
from aim.pytorch_ignite import AimLogger

import matplotlib.pyplot as plt
import seaborn as sns

# Create a logger
aim_logger = AimLogger()
...
@trainer.on(Events.COMPLETED)
def log_confusion_matrix(trainer):
    metrics = val_evaluator.state.metrics
    cm = metrics['cm']
    cm = cm.numpy()
    cm = cm.astype(int)
    classes = ['T-shirt/top','Trouser','Pullover','Dress','Coat','Sandal','Shirt','Sneaker','Bag','Ankle Boot']
    fig, ax = plt.subplots(figsize=(10,10))
    ax= plt.subplot()
    sns.heatmap(cm, annot=True, ax = ax,fmt="d")
    # labels, title and ticks
    ax.set_xlabel('Predicted labels')
    ax.set_ylabel('True labels')
    ax.set_title('Confusion Matrix')
    ax.xaxis.set_ticklabels(classes,rotation=90)
    ax.yaxis.set_ticklabels(classes,rotation=0)
    aim_logger.experiment.track(Image(fig), name='cm_training_end')

使用Pytorch Ignite还有第三种扩展集成的方法。 例如,Pytorch Ignite的Tensorboard记录器提供了跟踪模型梯度和权重直方图的功能。 使用Aim也可以实现同样的效果

from typing import Optional, Union

import torch.nn as nn
from ignite.contrib.handlers.base_logger import BaseWeightsHistHandler
from ignite.engine import Engine, Events

from aim.pytorch_ignite import AimLogger
from aim import Distribution


class AimGradsHistHandler(BaseWeightsHistHandler):
    def __init__(self, model: nn.Module, tag: Optional[str] = None):
        super(GradsHistHandler, self).__init__(model, tag=tag)

    def __call__(self, engine: Engine, logger: AimLogger, event_name: Union[str, Events]) -> None:
        global_step = engine.state.get_event_attrib_value(event_name)
        context = {'subset': self.tag} if self.tag else {}
        for name, p in self.model.named_parameters():
            if p.grad is None:
                continue
            name = name.replace(".", "/")
            logger.experiment.track(
                Distribution(p.grad.detach().cpu().numpy()),
                name=name,
                step=global_step,
                context=context
            )

# Create a logger
aim_logger = AimLogger()

# Attach the logger to the trainer to log model's weights norm after each iteration
aim_logger.attach(
    trainer,
    event_name=Events.ITERATION_COMPLETED,
    log_handler=AimGradsHistHandler(model)
)

Pytorch Lightning

在Aim GitHub仓库提供的示例中,使用PL + Aim已经展示了如何自定义集成的参考方法。

def test_step(self, batch, batch_idx):
    ...
    # Track metrics manually
    self.logger.experiment.track(1, name='manually_tracked_metric')

因此,您可以在每个测试步骤的迭代中跟踪大量元数据:图像、文本,以及您需要的任何且Aim支持的内容。

Hugging Face

以下是扩展基础Hugging Face记录器的方法。 下面是一个从AimCallback派生的CustomCallback示例。这里主要的HF方法是重写的on_log()

这使我们能够追踪任何传递给on_log()方法的str对象,并将其记录为aim.Text

from aim.hugging_face import AimCallback
from aim import Text


class CustomCallback(AimCallback):
    def on_log(self, args, state, control,
               model=None, logs=None, **kwargs):
        super().on_log(args, state, control, model, logs, **kwargs)

        context = {
            'subset': self._current_shift,
        }
        for log_name, log_value in logs.items():
            if isinstance(log_value, str):
                self.experiment.track(Text(log_value), name=log_name, context=context)

TF/keras

以下是如何在使用Aim跟踪混淆矩阵的同时扩展tf.keras提供的默认回调。 我们采用并调整了这个示例以适应Aim。效果如下:

from aim.tensorflow import AimCallback

class CustomImageTrackingCallback(AimCallback):
    def __init__(self, data):
        super().__init__()
        self.data = data

    def on_epoch_end(self, epoch, logs=None):
        super().on_epoch_end(epoch, logs)
        from aim import Image
        # Use the model to predict the values from the validation dataset.
        test_pred_raw = self.model.predict(test_images)
        test_pred = np.argmax(test_pred_raw, axis=1)

        # Calculate the confusion matrix.
        cm = sklearn.metrics.confusion_matrix(test_labels, test_pred)
        # Log the confusion matrix as an image summary.
        figure = plot_confusion_matrix(cm, class_names=class_names)
        cm_image = Image(figure)

        # Log the confusion matrix as an Aim image.
        self.experiment.track(cm_image,"Confusion Matrix", step=epoch)

aim_callback = CustomImageTrackingCallback()

model.fit(
    train_images,
    train_labels,
    epochs=5,
    verbose=0, # Suppress chatty output
    callbacks=[aim_callback],
    validation_data=(test_images, test_labels),
)

XGBoost

以下是针对XGBoost重写AimCallback的方法。

from aim import Text
from aim.xgboost import AimCallback


class CustomCallback(AimCallback):

    def after_iteration(self, model, epoch, evals_log):
        for data, metric in evals_log.items():
            for metric_name, log in metric.items():
                self.experiment.track(Text(log), name=metric_name)

        return super().after_iteration(model, epoch, evals_log)

Catboost

Catboost的.fit()方法具有log_cout参数,可用于将日志输出重定向到自定义对象中,该对象需包含write属性。我们的日志记录器是一个实现了write方法的对象,该方法会根据日志内容解析日志字符串。因此,大部分日志输出将被我们的解析逻辑忽略,但您仍可以在我们的基础上编写自己的逻辑来满足您的特定需求。

from aim.catboost import AimLogger


class CustomLogger(AimLogger):

    def write(self, log):
        # Process the log string through our parser
        super().write(log)

        # Do your own parsing
        log = log.strip().split()
        if log[1] == 'bin:':
            value_bin = log[1][4:]
            value_score = self._to_number(log[3])
            self.experiment.track(value_score, name='score')

LightGBM

以下是针对LightGBM覆盖AimCallback的方法。

from aim.lightgbm import AimCallback


class CustomCallback(AimCallback):

    def before_tracking(self, env):
        for item in env.evaluation_result_list:
            # manipulate item here
            pass

    def after_tracking(self, env):
        # do any other action if necessary after tracking value
        pass