集成指南
Aim与您喜爱的ML框架无缝集成 - Pytorch Ignite、Pytorch Lightning、Hugging Face等。 基础集成指南可在快速入门部分找到。
在本节中,我们将深入探讨如何扩展基础日志记录器,通过调整它们来追踪更多内容。基础日志记录器仅能追踪特定指标和超参数。
Aim回调/适配器/记录器有两种扩展方式:
通过继承并重写负责日志记录的主要方法。
通过使用名为
experiment的公共属性,可以访问底层的aim.Run对象,从而轻松跟踪对项目有益的新指标、参数和其他元数据。
PyTorch Ignite
两种回调扩展机制均可与Pytorch Ignite配合使用。
在以下示例中,您将了解如何利用实验属性,在训练完成后通过aim.Image来跟踪混淆矩阵图像。
这里有一个示例colab笔记本。
from aim import Image
from aim.pytorch_ignite import AimLogger
import matplotlib.pyplot as plt
import seaborn as sns
# Create a logger
aim_logger = AimLogger()
...
@trainer.on(Events.COMPLETED)
def log_confusion_matrix(trainer):
metrics = val_evaluator.state.metrics
cm = metrics['cm']
cm = cm.numpy()
cm = cm.astype(int)
classes = ['T-shirt/top','Trouser','Pullover','Dress','Coat','Sandal','Shirt','Sneaker','Bag','Ankle Boot']
fig, ax = plt.subplots(figsize=(10,10))
ax= plt.subplot()
sns.heatmap(cm, annot=True, ax = ax,fmt="d")
# labels, title and ticks
ax.set_xlabel('Predicted labels')
ax.set_ylabel('True labels')
ax.set_title('Confusion Matrix')
ax.xaxis.set_ticklabels(classes,rotation=90)
ax.yaxis.set_ticklabels(classes,rotation=0)
aim_logger.experiment.track(Image(fig), name='cm_training_end')
使用Pytorch Ignite还有第三种扩展集成的方法。 例如,Pytorch Ignite的Tensorboard记录器提供了跟踪模型梯度和权重直方图的功能。 使用Aim也可以实现同样的效果
from typing import Optional, Union
import torch.nn as nn
from ignite.contrib.handlers.base_logger import BaseWeightsHistHandler
from ignite.engine import Engine, Events
from aim.pytorch_ignite import AimLogger
from aim import Distribution
class AimGradsHistHandler(BaseWeightsHistHandler):
def __init__(self, model: nn.Module, tag: Optional[str] = None):
super(GradsHistHandler, self).__init__(model, tag=tag)
def __call__(self, engine: Engine, logger: AimLogger, event_name: Union[str, Events]) -> None:
global_step = engine.state.get_event_attrib_value(event_name)
context = {'subset': self.tag} if self.tag else {}
for name, p in self.model.named_parameters():
if p.grad is None:
continue
name = name.replace(".", "/")
logger.experiment.track(
Distribution(p.grad.detach().cpu().numpy()),
name=name,
step=global_step,
context=context
)
# Create a logger
aim_logger = AimLogger()
# Attach the logger to the trainer to log model's weights norm after each iteration
aim_logger.attach(
trainer,
event_name=Events.ITERATION_COMPLETED,
log_handler=AimGradsHistHandler(model)
)
Pytorch Lightning
在Aim GitHub仓库提供的示例中,使用PL + Aim已经展示了如何自定义集成的参考方法。
def test_step(self, batch, batch_idx):
...
# Track metrics manually
self.logger.experiment.track(1, name='manually_tracked_metric')
因此,您可以在每个测试步骤的迭代中跟踪大量元数据:图像、文本,以及您需要的任何且Aim支持的内容。
Hugging Face
以下是扩展基础Hugging Face记录器的方法。
下面是一个从AimCallback派生的CustomCallback示例。这里主要的HF方法是重写的on_log()。
这使我们能够追踪任何传递给on_log()方法的str对象,并将其记录为aim.Text。
from aim.hugging_face import AimCallback
from aim import Text
class CustomCallback(AimCallback):
def on_log(self, args, state, control,
model=None, logs=None, **kwargs):
super().on_log(args, state, control, model, logs, **kwargs)
context = {
'subset': self._current_shift,
}
for log_name, log_value in logs.items():
if isinstance(log_value, str):
self.experiment.track(Text(log_value), name=log_name, context=context)
TF/keras
以下是如何在使用Aim跟踪混淆矩阵的同时扩展tf.keras提供的默认回调。
我们采用并调整了这个示例以适应Aim。效果如下:
from aim.tensorflow import AimCallback
class CustomImageTrackingCallback(AimCallback):
def __init__(self, data):
super().__init__()
self.data = data
def on_epoch_end(self, epoch, logs=None):
super().on_epoch_end(epoch, logs)
from aim import Image
# Use the model to predict the values from the validation dataset.
test_pred_raw = self.model.predict(test_images)
test_pred = np.argmax(test_pred_raw, axis=1)
# Calculate the confusion matrix.
cm = sklearn.metrics.confusion_matrix(test_labels, test_pred)
# Log the confusion matrix as an image summary.
figure = plot_confusion_matrix(cm, class_names=class_names)
cm_image = Image(figure)
# Log the confusion matrix as an Aim image.
self.experiment.track(cm_image,"Confusion Matrix", step=epoch)
aim_callback = CustomImageTrackingCallback()
model.fit(
train_images,
train_labels,
epochs=5,
verbose=0, # Suppress chatty output
callbacks=[aim_callback],
validation_data=(test_images, test_labels),
)
XGBoost
以下是针对XGBoost重写AimCallback的方法。
from aim import Text
from aim.xgboost import AimCallback
class CustomCallback(AimCallback):
def after_iteration(self, model, epoch, evals_log):
for data, metric in evals_log.items():
for metric_name, log in metric.items():
self.experiment.track(Text(log), name=metric_name)
return super().after_iteration(model, epoch, evals_log)
Catboost
Catboost的.fit()方法具有log_cout参数,可用于将日志输出重定向到自定义对象中,该对象需包含write属性。我们的日志记录器是一个实现了write方法的对象,该方法会根据日志内容解析日志字符串。因此,大部分日志输出将被我们的解析逻辑忽略,但您仍可以在我们的基础上编写自己的逻辑来满足您的特定需求。
from aim.catboost import AimLogger
class CustomLogger(AimLogger):
def write(self, log):
# Process the log string through our parser
super().write(log)
# Do your own parsing
log = log.strip().split()
if log[1] == 'bin:':
value_bin = log[1][4:]
value_score = self._to_number(log[3])
self.experiment.track(value_score, name='score')
LightGBM
以下是针对LightGBM覆盖AimCallback的方法。
from aim.lightgbm import AimCallback
class CustomCallback(AimCallback):
def before_tracking(self, env):
for item in env.evaluation_result_list:
# manipulate item here
pass
def after_tracking(self, env):
# do any other action if necessary after tracking value
pass