收集¶
收集 - 13¶
版本¶
名称: Gather (GitHub)
域名:
mainsince_version:
13函数:
Falsesupport_level:
SupportType.COMMON形状推断:
True
此版本的运算符自版本13起可用。
摘要¶
给定秩为 r >= 1 的 data 张量,以及秩为 q 的 indices 张量,收集由 indices 索引的 data 的轴维度(默认最外层轴为 axis=0)的条目,并将它们连接成一个秩为 q + (r - 1) 的输出张量。
这是一个索引操作,它沿着一个(指定的)轴对输入的data进行索引。
indices中的每个条目都会生成输入张量的一个r-1维切片。
整个操作在概念上生成一个由r-1维切片组成的q维张量,
该张量被排列成一个q + (r-1)维张量,其中q维度取代了原始被索引的axis。
以下几个例子说明了Gather如何针对特定形状的data、indices以及给定的axis值工作:
数据形状 |
索引形状 |
轴 |
输出形状 |
输出方程 |
|---|---|---|---|---|
(P, Q) |
( ) (一个标量) |
0 |
(问题) |
输出[q] = 数据[索引, q] |
(P, Q, R) |
( ) (一个标量) |
1 |
(P, R) |
输出[p, r] = 数据[p, 索引, r] |
(P, Q) |
(R, S) |
0 |
(R, S, Q) |
输出[r, s, q] = 数据[ [索引[r, s], q] |
(P, Q) |
(R, S) |
1 |
(P, R, S) |
输出[p, r, s] = 数据[ p, 索引[r, s]] |
更一般地,如果 axis = 0,设 k = indices[i_{0}, ..., i_{q-1}]
则 output[i_{0}, ..., i_{q-1}, j_{0}, ..., j_{r-2}] = input[k , j_{0}, ..., j_{r-2}]:
data = [
[1.0, 1.2],
[2.3, 3.4],
[4.5, 5.7],
]
indices = [
[0, 1],
[1, 2],
]
output = [
[
[1.0, 1.2],
[2.3, 3.4],
],
[
[2.3, 3.4],
[4.5, 5.7],
],
]
如果 axis = 1,设 k = indices[i_{0}, ..., i_{q-1}]
则 output[j_{0}, i_{0}, ..., i_{q-1}, j_{1}, ..., j_{r-2}] = input[j_{0}, k, j_{1}, ..., j_{r-2}]:
data = [
[1.0, 1.2, 1.9],
[2.3, 3.4, 3.9],
[4.5, 5.7, 5.9],
]
indices = [
[0, 2],
]
axis = 1,
output = [
[[1.0, 1.9]],
[[2.3, 3.9]],
[[4.5, 5.9]],
]
属性¶
axis - INT (默认为
'0'):在哪个轴上收集。负值表示从后面开始计算维度。可接受的范围是[-r, r-1],其中r = rank(data)。
输入¶
data (异构) - T:
秩为 r >= 1 的张量。
indices(异构) - Tind:
int32/int64索引的张量,可以是任何秩q。所有索引值都应在大小为s的轴上的边界[-s, s-1]内。如果任何索引值超出边界,则会出现错误。
输出¶
输出 (异构) - T:
秩为 q + (r - 1) 的张量。
类型约束¶
T 在 (
tensor(bfloat16),tensor(bool),tensor(complex128),tensor(complex64),tensor(double),tensor(float),tensor(float16),tensor(int16),tensor(int32),tensor(int64),tensor(int8),tensor(string),tensor(uint16),tensor(uint32),tensor(uint64),tensor(uint8)):将输入和输出类型限制为任何张量类型。
Tind 在 (
tensor(int32),tensor(int64)) 中:将索引限制为整数类型
收集 - 11¶
版本¶
名称: Gather (GitHub)
域名:
mainsince_version:
11函数:
Falsesupport_level:
SupportType.COMMON形状推断:
True
此版本的运算符自版本11起可用。
摘要¶
给定秩为 r >= 1 的 data 张量,以及秩为 q 的 indices 张量,收集由 indices 索引的 data 的轴维度(默认最外层轴为 axis=0)的条目,并将它们连接成一个秩为 q + (r - 1) 的输出张量。
axis = 0 :
让 k = indices[i_{0}, …, i_{q-1}] 然后 output[i_{0}, …, i_{q-1}, j_{0}, …, j_{r-2}] = input[k , j_{0}, …, j_{r-2}]
data = [
[1.0, 1.2],
[2.3, 3.4],
[4.5, 5.7],
]
indices = [
[0, 1],
[1, 2],
]
output = [
[
[1.0, 1.2],
[2.3, 3.4],
],
[
[2.3, 3.4],
[4.5, 5.7],
],
]
轴 = 1 :
让 k = indices[i_{0}, …, i_{q-1}] 然后 output[j_{0}, i_{0}, …, i_{q-1}, j_{1}, …, j_{r-2}] = input[j_{0}, k, j_{1}, …, j_{r-2}]
data = [
[1.0, 1.2, 1.9],
[2.3, 3.4, 3.9],
[4.5, 5.7, 5.9],
]
indices = [
[0, 2],
]
axis = 1,
output = [
[[1.0, 1.9]],
[[2.3, 3.9]],
[[4.5, 5.9]],
]
属性¶
axis - INT (默认为
'0'):在哪个轴上收集。负值表示从后面开始计算维度。可接受的范围是[-r, r-1],其中r = rank(data)。
输入¶
data (异构) - T:
秩为 r >= 1 的张量。
indices(异构) - Tind:
int32/int64索引的张量,可以是任何秩q。所有索引值都应在大小为s的轴上的边界[-s, s-1]内。如果任何索引值超出边界,则会出现错误。
输出¶
输出 (异构) - T:
秩为 q + (r - 1) 的张量。
类型约束¶
T 在 (
tensor(bool),tensor(complex128),tensor(complex64),tensor(double),tensor(float),tensor(float16),tensor(int16),tensor(int32),tensor(int64),tensor(int8),tensor(string),tensor(uint16),tensor(uint32),tensor(uint64),tensor(uint8)):将输入和输出类型限制为任何张量类型。
Tind 在 (
tensor(int32),tensor(int64)) 中:将索引限制为整数类型
收集 - 1¶
版本¶
名称: Gather (GitHub)
域名:
mainsince_version:
1函数:
Falsesupport_level:
SupportType.COMMON形状推断:
True
此版本的运算符自版本1起可用。
总结¶
给定秩为 r >= 1 的 data 张量,以及秩为 q 的 indices 张量,收集由 indices 索引的 data 的轴维度(默认最外层轴为 axis=0)的条目,并将它们连接成一个秩为 q + (r - 1) 的输出张量。
示例 1:
data = [
[1.0, 1.2],
[2.3, 3.4],
[4.5, 5.7],
]
indices = [
[0, 1],
[1, 2],
]
output = [
[
[1.0, 1.2],
[2.3, 3.4],
],
[
[2.3, 3.4],
[4.5, 5.7],
],
]
示例 2:
data = [
[1.0, 1.2, 1.9],
[2.3, 3.4, 3.9],
[4.5, 5.7, 5.9],
]
indices = [
[0, 2],
]
axis = 1,
output = [
[[1.0, 1.9]],
[[2.3, 3.9]],
[[4.5, 5.9]],
]
属性¶
axis - INT (默认为
'0'):在哪个轴上收集。负值表示从后面开始计算维度。接受的范围是 [-r, r-1]
输入¶
data (异构) - T:
秩为 r >= 1 的张量。
indices(异构) - Tind:
int32/int64索引的张量,可以是任何秩q。所有索引值都应在范围内。如果任何索引值超出范围,则是一个错误。
输出¶
输出 (异构) - T:
秩为 q + (r - 1) 的张量。
类型约束¶
T 在 (
tensor(bool),tensor(complex128),tensor(complex64),tensor(double),tensor(float),tensor(float16),tensor(int16),tensor(int32),tensor(int64),tensor(int8),tensor(string),tensor(uint16),tensor(uint32),tensor(uint64),tensor(uint8)):将输入和输出类型限制为任何张量类型。
Tind 在 (
tensor(int32),tensor(int64)) 中:将索引限制为整数类型