英特尔® PyTorch* 扩展¶
创建于:2021年11月09日 | 最后更新:2024年7月25日 | 最后验证:2024年11月05日
Intel® Extension for PyTorch* 扩展了 PyTorch*,提供了最新的功能优化,以在 Intel 硬件上获得额外的性能提升。优化利用了 Intel CPU 上的 AVX-512 向量神经网络指令(AVX512 VNNI)和 Intel® 高级矩阵扩展(Intel® AMX),以及 Intel 独立 GPU 上的 Xe 矩阵扩展(XMX)AI 引擎。此外,通过 PyTorch* 的 xpu 设备,Intel® Extension for PyTorch* 为 Intel 独立 GPU 提供了简单的 GPU 加速支持。
Intel® Extension for PyTorch* 已在 Github 上作为开源项目发布。
CPU的源代码可在main branch获取。
GPU的源代码可在xpu-main branch获取。
功能¶
Intel® Extension for PyTorch* 在 CPU 和 GPU 上共享大部分功能。
易用的Python API: Intel® Extension for PyTorch* 提供了简单的前端Python API和工具,使用户能够通过少量代码更改获得性能优化,如图优化和操作符优化。通常,只需在原始代码中添加2到3条子句即可。
Channels Last: 与默认的NCHW内存格式相比,channels_last(NHWC)内存格式可以进一步加速卷积神经网络。在Intel® Extension for PyTorch*中,NHWC内存格式已经为大多数关键CPU操作符启用,尽管它们尚未全部合并到PyTorch主分支中。预计它们很快将完全集成到PyTorch上游。
自动混合精度 (AMP): 低精度数据类型 BFloat16 已在第三代 Xeon 可扩展服务器(也称为 Cooper Lake)上原生支持,并将在下一代 Intel® Xeon® 可扩展处理器上支持,配备 Intel® 高级矩阵扩展 (Intel® AMX) 指令集,性能进一步提升。在 Intel® PyTorch* 扩展中,已大量启用了对 CPU 的 BFloat16 自动混合精度 (AMP) 支持以及操作符的 BFloat16 优化,并部分上溯到 PyTorch 主分支。大多数这些优化将通过正在提交和审查的 PR 登陆 PyTorch 主分支。对于 Intel 独立 GPU,已启用了 BFloat16 和 Float16 的自动混合精度 (AMP)。
图优化: 为了通过torchscript进一步优化性能, Intel® Extension for PyTorch* 支持常用操作符模式的融合,如Conv2D+ReLU、Linear+ReLU等。融合的好处以透明的方式提供给用户。支持的详细融合模式可以在这里找到。 随着oneDNN Graph API的引入,图优化将被上游到PyTorch。
操作符优化: Intel® Extension for PyTorch* 还优化了操作符,并实现了几个定制的操作符以提高性能。一些 ATen 操作符通过 ATen 注册机制被 Intel® Extension for PyTorch* 中的优化版本所取代。此外,还为一些流行的拓扑结构实现了一些定制的操作符。例如,ROIAlign 和 NMS 在 Mask R-CNN 中定义。为了提高这些拓扑结构的性能,Intel® Extension for PyTorch* 还优化了这些定制的操作符。
入门指南¶
用户只需进行少量代码更改即可开始使用Intel® Extension for PyTorch*。PyTorch的命令式模式和TorchScript模式均受支持。本节介绍Intel® Extension for PyTorch* API函数在命令式模式和TorchScript模式下的使用,涵盖数据类型Float32和BFloat16。最后还将介绍C++的使用。
您只需要导入Intel® Extension for PyTorch*包,并将其优化函数应用于模型对象。如果是训练工作负载,优化函数还需要应用于优化器对象。
为了使用BFloat16数据类型进行训练和推理,torch.cpu.amp已在PyTorch上游启用,以方便支持混合精度。BFloat16数据类型已在PyTorch上游和Intel® Extension for PyTorch*中广泛启用用于CPU操作符。同时,由Intel® Extension for PyTorch*注册的torch.xpu.amp,使得在Intel独立GPU上轻松使用BFloat16和Float16数据类型成为可能。无论是torch.cpu.amp还是torch.xpu.amp,都会自动将每个操作符匹配到其适当的数据类型,并返回最佳性能。
示例 – CPU¶
本节展示了使用英特尔® PyTorch* 扩展在 CPU 上进行训练和推理的示例。
Intel® Extension for PyTorch* 所需的代码更改已突出显示。
训练¶
Float32¶
import torch
import torchvision
import intel_extension_for_pytorch as ipex
LR = 0.001
DOWNLOAD = True
DATA = 'datasets/cifar10/'
transform = torchvision.transforms.Compose([
torchvision.transforms.Resize((224, 224)),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_dataset = torchvision.datasets.CIFAR10(
root=DATA,
train=True,
transform=transform,
download=DOWNLOAD,
)
train_loader = torch.utils.data.DataLoader(
dataset=train_dataset,
batch_size=128
)
model = torchvision.models.resnet50()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr = LR, momentum=0.9)
model.train()
model, optimizer = ipex.optimize(model, optimizer=optimizer)
for batch_idx, (data, target) in enumerate(train_loader):
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
print(batch_idx)
torch.save({
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
}, 'checkpoint.pth')
BFloat16¶
import torch
import torchvision
import intel_extension_for_pytorch as ipex
LR = 0.001
DOWNLOAD = True
DATA = 'datasets/cifar10/'
transform = torchvision.transforms.Compose([
torchvision.transforms.Resize((224, 224)),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_dataset = torchvision.datasets.CIFAR10(
root=DATA,
train=True,
transform=transform,
download=DOWNLOAD,
)
train_loader = torch.utils.data.DataLoader(
dataset=train_dataset,
batch_size=128
)
model = torchvision.models.resnet50()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr = LR, momentum=0.9)
model.train()
model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=torch.bfloat16)
for batch_idx, (data, target) in enumerate(train_loader):
optimizer.zero_grad()
with torch.cpu.amp.autocast():
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
print(batch_idx)
torch.save({
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
}, 'checkpoint.pth')
推理 - 命令模式¶
Float32¶
import torch
import torchvision.models as models
model = models.resnet50(pretrained=True)
model.eval()
data = torch.rand(1, 3, 224, 224)
#################### code changes ####################
import intel_extension_for_pytorch as ipex
model = ipex.optimize(model)
######################################################
with torch.no_grad():
model(data)
BFloat16¶
import torch
from transformers import BertModel
model = BertModel.from_pretrained(args.model_name)
model.eval()
vocab_size = model.config.vocab_size
batch_size = 1
seq_length = 512
data = torch.randint(vocab_size, size=[batch_size, seq_length])
#################### code changes ####################
import intel_extension_for_pytorch as ipex
model = ipex.optimize(model, dtype=torch.bfloat16)
######################################################
with torch.no_grad():
with torch.cpu.amp.autocast():
model(data)
推理 - TorchScript 模式¶
TorchScript模式使得图优化成为可能,从而提高了某些拓扑结构的性能。Intel® Extension for PyTorch* 启用了最常用的操作符模式融合,用户无需额外的代码更改即可获得性能优势。
Float32¶
import torch
import torchvision.models as models
model = models.resnet50(pretrained=True)
model.eval()
data = torch.rand(1, 3, 224, 224)
#################### code changes ####################
import intel_extension_for_pytorch as ipex
model = ipex.optimize(model)
######################################################
with torch.no_grad():
d = torch.rand(1, 3, 224, 224)
model = torch.jit.trace(model, d)
model = torch.jit.freeze(model)
model(data)
BFloat16¶
import torch
from transformers import BertModel
model = BertModel.from_pretrained(args.model_name)
model.eval()
vocab_size = model.config.vocab_size
batch_size = 1
seq_length = 512
data = torch.randint(vocab_size, size=[batch_size, seq_length])
#################### code changes ####################
import intel_extension_for_pytorch as ipex
model = ipex.optimize(model, dtype=torch.bfloat16)
######################################################
with torch.no_grad():
with torch.cpu.amp.autocast():
d = torch.randint(vocab_size, size=[batch_size, seq_length])
model = torch.jit.trace(model, (d,), check_trace=False, strict=False)
model = torch.jit.freeze(model)
model(data)
示例 – GPU¶
本节展示了使用英特尔® PyTorch* 扩展在 GPU 上进行训练和推理的示例。
Intel® Extension for PyTorch* 所需的代码更改在上方的注释行中突出显示。
训练¶
Float32¶
import torch
import torchvision
############# code changes ###############
import intel_extension_for_pytorch as ipex
############# code changes ###############
LR = 0.001
DOWNLOAD = True
DATA = 'datasets/cifar10/'
transform = torchvision.transforms.Compose([
torchvision.transforms.Resize((224, 224)),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_dataset = torchvision.datasets.CIFAR10(
root=DATA,
train=True,
transform=transform,
download=DOWNLOAD,
)
train_loader = torch.utils.data.DataLoader(
dataset=train_dataset,
batch_size=128
)
model = torchvision.models.resnet50()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr = LR, momentum=0.9)
model.train()
#################################### code changes ################################
model = model.to("xpu")
model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=torch.float32)
#################################### code changes ################################
for batch_idx, (data, target) in enumerate(train_loader):
########## code changes ##########
data = data.to("xpu")
target = target.to("xpu")
########## code changes ##########
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
print(batch_idx)
torch.save({
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
}, 'checkpoint.pth')
BFloat16¶
import torch
import torchvision
############# code changes ###############
import intel_extension_for_pytorch as ipex
############# code changes ###############
LR = 0.001
DOWNLOAD = True
DATA = 'datasets/cifar10/'
transform = torchvision.transforms.Compose([
torchvision.transforms.Resize((224, 224)),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_dataset = torchvision.datasets.CIFAR10(
root=DATA,
train=True,
transform=transform,
download=DOWNLOAD,
)
train_loader = torch.utils.data.DataLoader(
dataset=train_dataset,
batch_size=128
)
model = torchvision.models.resnet50()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr = LR, momentum=0.9)
model.train()
##################################### code changes ################################
model = model.to("xpu")
model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=torch.bfloat16)
##################################### code changes ################################
for batch_idx, (data, target) in enumerate(train_loader):
optimizer.zero_grad()
######################### code changes #########################
data = data.to("xpu")
target = target.to("xpu")
with torch.xpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
######################### code changes #########################
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
print(batch_idx)
torch.save({
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
}, 'checkpoint.pth')
推理 - 命令模式¶
Float32¶
import torch
import torchvision.models as models
############# code changes ###############
import intel_extension_for_pytorch as ipex
############# code changes ###############
model = models.resnet50(pretrained=True)
model.eval()
data = torch.rand(1, 3, 224, 224)
model = model.to(memory_format=torch.channels_last)
data = data.to(memory_format=torch.channels_last)
#################### code changes ################
model = model.to("xpu")
data = data.to("xpu")
model = ipex.optimize(model, dtype=torch.float32)
#################### code changes ################
with torch.no_grad():
model(data)
BFloat16¶
import torch
import torchvision.models as models
############# code changes ###############
import intel_extension_for_pytorch as ipex
############# code changes ###############
model = models.resnet50(pretrained=True)
model.eval()
data = torch.rand(1, 3, 224, 224)
model = model.to(memory_format=torch.channels_last)
data = data.to(memory_format=torch.channels_last)
#################### code changes #################
model = model.to("xpu")
data = data.to("xpu")
model = ipex.optimize(model, dtype=torch.bfloat16)
#################### code changes #################
with torch.no_grad():
################################# code changes ######################################
with torch.xpu.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=False):
################################# code changes ######################################
model(data)
Float16¶
import torch
import torchvision.models as models
############# code changes ###############
import intel_extension_for_pytorch as ipex
############# code changes ###############
model = models.resnet50(pretrained=True)
model.eval()
data = torch.rand(1, 3, 224, 224)
model = model.to(memory_format=torch.channels_last)
data = data.to(memory_format=torch.channels_last)
#################### code changes ################
model = model.to("xpu")
data = data.to("xpu")
model = ipex.optimize(model, dtype=torch.float16)
#################### code changes ################
with torch.no_grad():
################################# code changes ######################################
with torch.xpu.amp.autocast(enabled=True, dtype=torch.float16, cache_enabled=False):
################################# code changes ######################################
model(data)
推理 - TorchScript 模式¶
TorchScript模式使得图优化成为可能,从而提高了某些拓扑结构的性能。Intel® Extension for PyTorch* 启用了最常用的操作符模式融合,用户无需额外的代码更改即可获得性能优势。
Float32¶
import torch
from transformers import BertModel
############# code changes ###############
import intel_extension_for_pytorch as ipex
############# code changes ###############
model = BertModel.from_pretrained(args.model_name)
model.eval()
vocab_size = model.config.vocab_size
batch_size = 1
seq_length = 512
data = torch.randint(vocab_size, size=[batch_size, seq_length])
#################### code changes ################
model = model.to("xpu")
data = data.to("xpu")
model = ipex.optimize(model, dtype=torch.float32)
#################### code changes ################
with torch.no_grad():
d = torch.randint(vocab_size, size=[batch_size, seq_length])
##### code changes #####
d = d.to("xpu")
##### code changes #####
model = torch.jit.trace(model, (d,), check_trace=False, strict=False)
model = torch.jit.freeze(model)
model(data)
BFloat16¶
import torch
from transformers import BertModel
############# code changes ###############
import intel_extension_for_pytorch as ipex
############# code changes ###############
model = BertModel.from_pretrained(args.model_name)
model.eval()
vocab_size = model.config.vocab_size
batch_size = 1
seq_length = 512
data = torch.randint(vocab_size, size=[batch_size, seq_length])
#################### code changes #################
model = model.to("xpu")
data = data.to("xpu")
model = ipex.optimize(model, dtype=torch.bfloat16)
#################### code changes #################
with torch.no_grad():
d = torch.randint(vocab_size, size=[batch_size, seq_length])
################################# code changes ######################################
d = d.to("xpu")
with torch.xpu.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=False):
################################# code changes ######################################
model = torch.jit.trace(model, (d,), check_trace=False, strict=False)
model = torch.jit.freeze(model)
model(data)
Float16¶
import torch
from transformers import BertModel
############# code changes ###############
import intel_extension_for_pytorch as ipex
############# code changes ###############
model = BertModel.from_pretrained(args.model_name)
model.eval()
vocab_size = model.config.vocab_size
batch_size = 1
seq_length = 512
data = torch.randint(vocab_size, size=[batch_size, seq_length])
#################### code changes ################
model = model.to("xpu")
data = data.to("xpu")
model = ipex.optimize(model, dtype=torch.float16)
#################### code changes ################
with torch.no_grad():
d = torch.randint(vocab_size, size=[batch_size, seq_length])
################################# code changes ######################################
d = d.to("xpu")
with torch.xpu.amp.autocast(enabled=True, dtype=torch.float16, cache_enabled=False):
################################# code changes ######################################
model = torch.jit.trace(model, (d,), check_trace=False, strict=False)
model = torch.jit.freeze(model)
model(data)
C++ (仅限CPU)¶
为了使用PyTorch的C++库libtorch,Intel® Extension for PyTorch* 也提供了其C++动态库。该C++库仅用于处理推理工作负载,例如服务部署。对于常规开发,请使用Python接口。与使用libtorch相比,除了将输入数据转换为通道最后的数据格式外,不需要特定的代码更改。编译遵循使用CMake的推荐方法。详细说明可以在PyTorch教程中找到。 在编译期间,一旦链接了Intel® Extension for PyTorch* 的C++动态库,Intel优化将自动激活。
example-app.cpp
#include <torch/script.h>
#include <iostream>
#include <memory>
int main(int argc, const char* argv[]) {
torch::jit::script::Module module;
try {
module = torch::jit::load(argv[1]);
}
catch (const c10::Error& e) {
std::cerr << "error loading the model\n";
return -1;
}
std::vector<torch::jit::IValue> inputs;
// make sure input data are converted to channels last format
inputs.push_back(torch::ones({1, 3, 224, 224}).to(c10::MemoryFormat::ChannelsLast));
at::Tensor output = module.forward(inputs).toTensor();
return 0;
}
CMakeLists.txt
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(example-app)
find_package(intel_ext_pt_cpu REQUIRED)
add_executable(example-app example-app.cpp)
target_link_libraries(example-app "${TORCH_LIBRARIES}")
set_property(TARGET example-app PROPERTY CXX_STANDARD 14)
编译命令
$ cmake -DCMAKE_PREFIX_PATH=<LIBPYTORCH_PATH> ..
$ make
如果显示Found INTEL_EXT_PT_CPU为TRUE,则表示扩展已链接到二进制文件中。这可以通过Linux命令ldd来验证。
$ cmake -DCMAKE_PREFIX_PATH=/workspace/libtorch ..
-- The C compiler identification is GNU 9.3.0
-- The CXX compiler identification is GNU 9.3.0
-- Check for working C compiler: /usr/bin/cc
-- Check for working C compiler: /usr/bin/cc -- works
-- Detecting C compiler ABI info
-- Detecting C compiler ABI info - done
-- Detecting C compile features
-- Detecting C compile features - done
-- Check for working CXX compiler: /usr/bin/c++
-- Check for working CXX compiler: /usr/bin/c++ -- works
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Detecting CXX compile features
-- Detecting CXX compile features - done
-- Looking for pthread.h
-- Looking for pthread.h - found
-- Performing Test CMAKE_HAVE_LIBC_PTHREAD
-- Performing Test CMAKE_HAVE_LIBC_PTHREAD - Failed
-- Looking for pthread_create in pthreads
-- Looking for pthread_create in pthreads - not found
-- Looking for pthread_create in pthread
-- Looking for pthread_create in pthread - found
-- Found Threads: TRUE
-- Found Torch: /workspace/libtorch/lib/libtorch.so
-- Found INTEL_EXT_PT_CPU: TRUE
-- Configuring done
-- Generating done
-- Build files have been written to: /workspace/build
$ ldd example-app
...
libtorch.so => /workspace/libtorch/lib/libtorch.so (0x00007f3cf98e0000)
libc10.so => /workspace/libtorch/lib/libc10.so (0x00007f3cf985a000)
libintel-ext-pt-cpu.so => /workspace/libtorch/lib/libintel-ext-pt-cpu.so (0x00007f3cf70fc000)
libtorch_cpu.so => /workspace/libtorch/lib/libtorch_cpu.so (0x00007f3ce16ac000)
...
libdnnl_graph.so.0 => /workspace/libtorch/lib/libdnnl_graph.so.0 (0x00007f3cde954000)
...
模型库(仅限CPU)¶
已经由英特尔工程师优化的用例可在
Model Zoo for Intel® Architecture(分支名称格式为pytorch-r