BERT sklearn

To use bert-sklearn with HiClass, some of scikit-learn's checks need to be disabled. The reason is that BERT expects text as its input features, whereas scikit-learn expects numeric features, so these checks would fail. To disable scikit-learn's checks, we can simply pass the parameter bert=True in the constructor of the local hierarchical classifier.
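To see why the flag is needed, here is a minimal sketch of what scikit-learn's standard input validation does with text features. It uses check_array, scikit-learn's general-purpose validation helper; whether HiClass calls exactly this helper internally is an assumption, but the failure mode is the same:

```python
from sklearn.utils.validation import check_array

# Text features, as bert-sklearn expects them
X_text = [["Batman"], ["Rorschach"]]

# scikit-learn's default validation tries to coerce features to numbers,
# so raw strings fail with a ValueError
try:
    check_array(X_text, dtype="numeric")
except ValueError as err:
    print(f"Validation failed: {err}")
```

With bert=True, HiClass skips this numeric validation and passes the raw strings straight through to the local BERT classifiers.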

Output:

Building sklearn text classifier...

  0%|          | 0/231508 [00:00<?, ?B/s]
100%|##########| 231508/231508 [00:00<00:00, 5987008.39B/s]
Loading bert-base-uncased model...

  0%|          | 0/440473133 [00:00<?, ?B/s]
100%|##########| 440473133/440473133 [00:06<00:00, 67900437.62B/s]

  0%|          | 0/433 [00:00<?, ?B/s]
100%|##########| 433/433 [00:00<00:00, 3308075.83B/s]
Defaulting to linear classifier/regressor
Loading Pytorch checkpoint
You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
This DataLoader will create 5 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
train data size: 2, validation data size: 0

Training  :   0%|          | 0/1 [00:00<?, ?it/s]
This overload of add_ is deprecated:
        add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
        add_(Tensor other, *, Number alpha = 1) (Triggered internally at ../torch/csrc/utils/python_arg_parser.cpp:1581.)

Training  :   0%|          | 0/1 [00:04<?, ?it/s, loss=0.749]
Training  : 100%|##########| 1/1 [00:04<00:00,  4.47s/it, loss=0.749]
Training  : 100%|##########| 1/1 [00:04<00:00,  4.56s/it, loss=0.749]

Training  :   0%|          | 0/1 [00:00<?, ?it/s]
Training  :   0%|          | 0/1 [00:04<?, ?it/s, loss=0.703]
Training  : 100%|##########| 1/1 [00:04<00:00,  4.96s/it, loss=0.703]
Training  : 100%|##########| 1/1 [00:05<00:00,  5.19s/it, loss=0.703]

Training  :   0%|          | 0/1 [00:00<?, ?it/s]
Training  :   0%|          | 0/1 [00:05<?, ?it/s, loss=0.687]
Training  : 100%|##########| 1/1 [00:05<00:00,  5.05s/it, loss=0.687]
Training  : 100%|##########| 1/1 [00:05<00:00,  5.31s/it, loss=0.687]

Predicting:   0%|          | 0/1 [00:00<?, ?it/s]
Predicting: 100%|##########| 1/1 [00:00<00:00,  1.47it/s]
Predicting: 100%|##########| 1/1 [00:00<00:00,  1.26it/s]
[['Action' 'The Dark Night']
 ['Action' 'Watchmen']]

from bert_sklearn import BertClassifier
from hiclass import LocalClassifierPerParentNode

# Define data
X_train = X_test = [
    "Batman",
    "Rorschach",
]
Y_train = [
    ["Action", "The Dark Night"],
    ["Action", "Watchmen"],
]

# Use BERT for every node
bert = BertClassifier()
classifier = LocalClassifierPerParentNode(
    local_classifier=bert,
    bert=True,
)

# Train local classifier per parent node
classifier.fit(X_train, Y_train)

# Predict
predictions = classifier.predict(X_test)
print(predictions)

Total running time of the script: (0 minutes 31.635 seconds)

Gallery generated by Sphinx-Gallery