torch.signal.windows.windows 的源代码
from typing import Optional, Iterable
import torch
from math import sqrt
from torch import Tensor
from torch._torch_docs import factory_common_args, parse_kwargs, merge_dicts
__all__ = [
'bartlett',
'blackman',
'cosine',
'exponential',
'gaussian',
'general_cosine',
'general_hamming',
'hamming',
'hann',
'kaiser',
'nuttall',
]
window_common_args = merge_dicts(
parse_kwargs(
"""
M (int): 窗口的长度。
换句话说,返回窗口的点数。
sym (bool, 可选): 如果为 `False`,返回适合在频谱分析中使用的周期性窗口。
如果为 `True`,返回适合在滤波器设计中使用的对称窗口。默认值: `True`。
"""
),
factory_common_args,
{
"normalization": "窗口被归一化为 1(最大值为 1)。然而,如果 :attr:`M` 是偶数且 :attr:`sym` 为 `True`,则不会出现 1。",
}
)
def _add_docstr(*args):
r"""为给定的装饰函数添加文档字符串。
当文档字符串需要字符串插值时特别有用,例如使用 str.format()。
注意:如果文档字符串不需要字符串插值,请直接编写常规文档字符串。
参数:
args (str):
"""
def decorator(o):
o.__doc__ = "".join(args)
return o
return decorator
def _window_function_checks(function_name: str, M: int, dtype: torch.dtype, layout: torch.layout) -> None:
r"""为所有定义的窗口执行通用检查。
在计算任何窗口之前应调用此函数。
参数:
function_name (str): 窗口函数的名称。
M (int): 窗口的长度。
dtype (:class:`torch.dtype`): 返回张量的所需数据类型。
layout (:class:`torch.layout`): 返回张量的所需布局。
"""
if M < 0:
raise ValueError(f'{function_name} 需要非负的窗口长度,得到 M={M}')
if layout is not torch.strided:
raise ValueError(f'{function_name} 仅针对分步张量实现,得到: {layout}')
if dtype not in [torch.float32, torch.float64]:
raise ValueError(f'{function_name} 期望 float32 或 float64 数据类型,得到: {dtype}')
[docs]@_add_docstr(
r"""
计算具有指数波形的窗口。
也称为泊松窗口。
指数窗口定义如下:
.. math::
w_n = \exp{\left(-\frac{|n - c|}{\tau}\right)}
其中 `c` 是窗口的 ``center``。
""",
r"""
{normalization}
参数:
{M}
关键字参数:
center (float, 可选): 窗口中心的位置。
默认值: `M / 2` 如果 `sym` 为 `False`,否则 `(M - 1) / 2`。
tau (float, 可选): 衰减值。
Tau 通常与百分比相关联,这意味着值应在区间 (0, 100] 内变化。如果 tau 为 100,则认为是均匀窗口。
默认值: 1.0。
{sym}
{dtype}
{layout}
{device}
{requires_grad}
示例::
>>> # 生成一个大小为 10 且衰减值为 1.0 的对称指数窗口。
>>> # 中心将在 (M - 1) / 2,其中 M 为 10。
>>> torch.signal.windows.exponential(10)
tensor([0.0111, 0.0302, 0.0821, 0.2231, 0.6065, 0.6065, 0.2231, 0.0821, 0.0302, 0.0111])
>>> # 生成一个周期性指数窗口和衰减因子等于 .5
>>> torch.signal.windows.exponential(10, sym=False,tau=.5)
tensor([4.5400e-05, 3.3546e-04, 2.4788e-03, 1.8316e-02, 1.3534e-01, 1.0000e+00, 1.3534e-01, 1.8316e-02, 2.4788e-03, 3.3546e-04])
""".format(
**window_common_args
),
)
def exponential(
M: int,
*,
center: Optional[float] = None,
tau: float = 1.0,
sym: bool = True,
dtype: Optional[torch.dtype] = None,
layout: torch.layout = torch.strided,
device: Optional[torch.device] = None,
requires_grad: bool = False
) -> Tensor:
if dtype is None:
dtype <span class="o