Shortcuts

ding.utils

自动日志

请参考 ding/utils/autolog 获取更多详细信息。

时间模式

class ding.utils.autolog.TimeMode(value)[source]
Overview:

用于决定range_values函数格式的模式

ABSOLUTE: 使用绝对时间 RELATIVE_LIFECYCLE: 使用基于属性生命周期的相对时间 RELATIVE_CURRENT_TIME: 使用基于当前时间的相对时间

ABSOLUTE = 0
RELATIVE_CURRENT_TIME = 2
RELATIVE_LIFECYCLE = 1

范围数据

class ding.utils.autolog.RangedData(expire: float, use_pickle: bool = False)[source]
Overview:

一种可以存储数据一段时间的数据结构。

Interfaces:

__init__, append, extend, current, history, expire, __bool__, _get_time.

Properties:
  • expire (float): 过期时间。

__append(time_: float, data: _Tp)
Overview:

追加数据。

__append_item(time_: float, data: _Tp)
Overview:

追加数据项。

Arguments:
  • time_ (float): 时间。

  • 数据 (_Tp): 数据项。

__check_expire()
Overview:

检查过期时间。

__check_time(time_: float)
Overview:

检查时间。

Arguments:
  • time_ (float): 时间。

__current()
Overview:

获取当前数据。

__flush_history()
Overview:

清除历史数据。

__get_data_item(data_id: int) _Tp
Overview:

获取数据项。

Arguments:
  • data_id (int): 数据ID。

__history()
Overview:

获取历史数据。

__history_yield()
Overview:

生成历史数据。

__init__(expire: float, use_pickle: bool = False)[source]
Overview:

初始化RangedData对象。

Arguments:
  • expire (float): 数据的过期时间。

  • use_pickle (bool): 是否使用pickle来序列化数据。

__registry_data_item(data: _Tp) int
Overview:

注册数据项。

Arguments:
  • 数据 (_Tp): 数据项。

__remove_data_item(data_id: int)
Overview:

删除数据项。

Arguments:
  • data_id (int): 数据ID。

_abc_impl = <_abc._abc_data object>
abstract _get_time() float[source]
Overview:

获取当前时间。

append(data: _Tp)[source]
Overview:

追加数据。

current() _Tp[source]
Overview:

获取当前数据。

property expire: float
Overview:

获取过期时间。

extend(iter_: Iterable[_Tp])[来源]
Overview:

扩展数据。

history() List[Tuple[int | float, _Tp]][source]
Overview:

获取历史数据。

时间范围数据

class ding.utils.autolog.TimeRangedData(time_: BaseTime, expire: float)[来源]
Overview:

一种可以存储数据一段时间的数据结构。

Interfaces:

__init__, _get_time, append, extend, current, history, expire, __bool__.

Properties:
  • 时间 (BaseTime): 时间。

  • expire (float): 过期时间。

__init__(time_: BaseTime, expire: float)[source]
Overview:

初始化 TimeRangedData 对象。

Arguments:
_abc_impl = <_abc._abc_data object>
_get_time() float[source]
Overview:

获取当前时间。

property time
Overview:

获取时间。

已记录模型

class ding.utils.autolog.LoggedModel(time_: _TimeObjectType, expire: _TimeType)[source]
Overview:

一个带有时间线的模型(整数时间,如第1、第2、第3,也可以建模为一种自定义的离散时间,如TickTime的实现)。通过使用LoggedModel,可以一起维护几个相互关联的值。

Example:

像这样定义AvgList模型

>>> from ding.utils.autolog import LoggedValue, LoggedModel
>>> class AvgList(LoggedModel):
>>>     value = LoggedValue(float)
>>>     __property_names = ['value']
>>>
>>>     def __init__(self, time_: BaseTime, expire: Union[int, float]):
>>>         LoggedModel.__init__(self, time_, expire)
>>>         # attention, original value must be set in __init__ function, or it will not
>>>         # be activated, the timeline of this value will also be unexpectedly affected.
>>>         self.value = 0.0
>>>         self.__register()
>>>
>>>     def __register(self):
>>>         def __avg_func(prop_name: str) -> float:  # function to calculate average value of properties
>>>             records = self.range_values[prop_name]()
>>>             (_start_time, _), _ = records[0]
>>>             (_, _end_time), _ = records[-1]
>>>
>>>             _duration = _end_time - _start_time
>>>             _sum = sum([_value * (_end_time - _begin_time) for (_begin_time, _end_time), _value in records])
>>>
>>>             return _sum / _duration
>>>
>>>         for _prop_name in self.__property_names:
>>>             self.register_attribute_value('avg', _prop_name, partial(__avg_func, prop_name=_prop_name))

像这样使用它

>>> from ding.utils.autolog import NaturalTime, TimeMode
>>>
>>> if __name__ == "__main__":
>>>     _time = NaturalTime()
>>>     ll = AvgList(_time, expire=10)
>>>
>>>     # just do something here ...
>>>
>>>     print(ll.range_values['value']()) # original range_values function in LoggedModel of last 10 secs
>>>     print(ll.range_values['value'](TimeMode.ABSOLUTE))  # use absolute time
>>>     print(ll.avg['value']())  # average value of last 10 secs
Interfaces:

__init__, time, expire, fixed_time, current_time, freeze, unfreeze, register_attribute_value, __getattr__, get_property_attribute

Property:
  • 时间 (BaseTime): 时间。

  • expire (float): 过期时间。

__get_property_ranged_data(name: str) TimeRangedData
Overview:

获取一个属性的范围数据。

Arguments:
  • 名称 (str): 属性名称。

__get_range_values_func(name: str)
Overview:

获取一个属性的range_values函数。

Arguments:
  • 名称 (str): 属性名称。

__init__(time_: _TimeObjectType, expire: _TimeType)[来源]
Overview:

使用给定的参数初始化LoggedModel对象。

Arguments:
__init_properties()
Overview:

初始化所有属性。

property __properties: List[str]
Overview:

获取所有属性名称。

__register_default_funcs()
Overview:

注册默认函数。

_abc_impl = <_abc._abc_data object>
current_time() float | int[source]
Overview:

获取当前时间(不考虑时间代理的冻结声明的实时时间)

Returns:

int 或 float: 当前时间

property expire: _TimeType
Overview:

获取过期时间

Returns:

int 或 float: 旧值记录过期的时间

fixed_time() float | int[source]
Overview:

获取固定时间(如果时间代理被冻结,则将是冻结时间) 当添加值回放功能时,此功能可能很有用(在未来)

Returns:

int 或 float: 固定时间

freeze()[source]
Overview:

冻结时间代理对象。 当添加值回放功能时,此功能可能很有用(在未来)

get_property_attribute(property_name: str) List[str][来源]
Overview:

查找给定属性的所有已注册属性(不包括常见的“range_values”属性,因为“range_values”未添加到self.__prop2attr中)。

Arguments:
  • property_name (str): 查询属性的属性名称

Returns:
  • attr_list (List[str]): 输入属性的注册属性列表

register_attribute_value(attribute_name: str, property_name: str, value: Any)[source]
Overview:

为其中一个值注册一个新属性。示例可以在类的概述中找到。

Arguments:
  • 属性名称 (str): 属性的名称

  • property_name (str): 属性名称

  • 值 (Any): 属性的值

property time: _TimeObjectType
Overview:

获取传入的原始时间对象,可以通过此属性执行方法(例如 step())。

Returns:

BaseTime: 此模型使用的时间对象

unfreeze()[来源]
Overview:

解冻时间代理对象。 此功能在添加值回放功能时可能有用(未来)

基础时间

class ding.utils.autolog.BaseTime[source]
Overview:

抽象时间接口

Interfaces:

time

_abc_impl = <_abc._abc_data object>
abstract time() int | float[source]
Overview:

获取时间信息

Returns:
  • 时间(float, int): 时间信息

自然时间

class ding.utils.autolog.NaturalTime[来源]
Overview:

自然时间对象

Interfaces:

__init__, time

Example:
>>> from ding.utils.autolog.time_ctl import NaturalTime
>>> time_ = NaturalTime()
__init__()[source]
_abc_impl = <_abc._abc_data object>
time() float[source]
Overview:

获取当前自然时间(浮点格式,Unix时间戳)

Returns:
  • 时间(float): Unix时间戳

Example:
>>> from ding.utils.autolog.time_ctl import NaturalTime
>>> time_ = NaturalTime()
>>> time_.time()
1603896383.8811457

TickTime

class ding.utils.autolog.TickTime(init: int = 0)[source]
Overview:

时间刻度对象

Interfaces:

__init__, step, time

Example:
>>> from ding.utils.autolog.time_ctl import TickTime
>>> time_ = TickTime()
__init__(init: int = 0)[source]
Overview:

TickTime的构造函数

Arguments:
  • init (int): 初始时间,默认为0

_abc_impl = <_abc._abc_data object>
step(delta: int = 1) int[source]
Overview

为这个TickTime向前推进时间

Arguments:
  • delta (int): 向前步进的步数,默认为1

Returns:
  • 时间 (int): 步进后的新时间

Example:
>>> from ding.utils.autolog.time_ctl import TickTime
>>> time_ = TickTime(0)
>>> time_.step()
1
>>> time_.step(2)
3
time() int[source]
Overview

获取当前刻度时间

Returns:

int: 当前刻度时间

Example:
>>> from ding.utils.autolog.time_ctl import TickTime
>>> time_ = TickTime(0)
>>> time_.step()
>>> time_.time()
1

时间代理

class ding.utils.autolog.TimeProxy(time_: BaseTime, frozen: bool = False, lock_type: LockContextType = LockContextType.THREAD_LOCK)[source]
Overview:

时间对象的代理,它可以冻结时间,有时在重现时很有用。 该对象是线程安全的,并且冻结和解冻操作是严格有序的。

Interfaces:

__init__, freeze, unfreeze, time, current_time

Examples:
>>> from ding.utils.autolog.time_ctl import TickTime, TimeProxy
>>> tick_time_ = TickTime()
>>> time_ = TimeProxy(tick_time_)
>>> tick_time_.step()
>>> print(tick_time_.time(), time_.time(), time_.current_time())
1 1 1
>>> time_.freeze()
>>> tick_time_.step()
>>> print(tick_time_.time(), time_.time(), time_.current_time())
2 1 2
>>> time_.unfreeze()
>>> print(tick_time_.time(), time_.time(), time_.current_time())
2 2 2
__init__(time_: BaseTime, frozen: bool = False, lock_type: LockContextType = LockContextType.THREAD_LOCK)[source]
Overview:

Time代理的构造函数

Arguments:
  • time_ (BaseTime): 另一个基于它的时间对象

  • frozen (bool): 如果为真,此对象将立即被冻结,否则不会,默认为 False

  • lock_type (LockContextType): 锁的类型,默认为 THREAD_LOCK

_abc_impl = <_abc._abc_data object>
current_time() int | float[source]
Overview:

获取当前时间(不会冻结时间)

Returns:

int 或 float: 当前时间

freeze()[源代码]
Overview:

冻结此次代理

property is_frozen: bool
Overview:

获取此时代理对象是否被冻结

Returns:

布尔值:如果被冻结则为 true,否则为 false

time() int | float[来源]
Overview:

获取时间(可能是冻结时间)

Returns:

int 或 float: 时间

unfreeze()[source]
Overview:

解冻此次代理

记录值

class ding.utils.autolog.LoggedValue(type_: ~typing.Type[~ding.utils.autolog.base._ValueType] = <class 'object'>)[source]
Overview:

LoggedValue 可以作为 LoggedModel 中的属性使用,因为它具有 __get__ 和 __set__ 方法。 此类的实例将与它们的所有者 LoggedModel 实例相关联,一个 LoggedModel 的所有 LoggedValue 将共享唯一的时间对象(在 time_ctl 中定义),以便可以正确管理时间线。

Interfaces:

__init__, __get__, __set__

Properties:
  • __property_name (str): 属性的名称。

__get_ranged_data(instance) TimeRangedData
Overview:

获取范围数据。

Interfaces:

__get_ranged_data

__init__(type_: ~typing.Type[~ding.utils.autolog.base._ValueType] = <class 'object'>)[源代码]
Overview:

初始化 LoggedValue 对象。

Interfaces:

__init__

property __property_name
Overview:

获取属性的名称。

数据结构

请参考 ding/utils/data/structure 获取更多详细信息。

缓存

class ding.utils.data.structure.Cache(maxlen: int, timeout: float, monitor_interval: float = 1.0, _debug: bool = False)[source]
Overview:

用于减少并发压力的数据缓存,具有超时和队列满时弹出机制

Interfaces:

__init__, push_data, get_cached_data_iter, run, close

Property:

剩余数据计数

__init__(maxlen: int, timeout: float, monitor_interval: float = 1.0, _debug: bool = False) None[源代码]
Overview:

初始化缓存对象。

Arguments:
  • maxlen (int): 缓存队列的最大长度。

  • 超时时间 (float): 数据在缓存中可以保留的最大秒数。

  • monitor_interval (float): 超时监控线程检查时间的间隔。

  • _debug (bool): 是否使用调试模式,该模式会启用调试打印信息。

_timeout_monitor() None[source]
Overview:

超时监控线程的工作流程。

_warn_if_timeout() bool[source]
Overview:

返回是否超时。

Returns
  • 结果: (bool) 是否超时。

close() None[source]
Overview:

关闭缓存内部线程并向发送队列的迭代器发送结束标志。

dprint(s: str) None[源代码]
Overview:

在调试模式下,打印调试字符串。

Arguments:
  • s (str): 要打印的调试信息。

get_cached_data_iter() callable_iterator[source]
Overview:

获取发送队列的迭代器。一旦数据被推入发送队列,就可以通过这个迭代器访问它。'STOP' 是这个迭代器的结束标志。

Returns:
  • 迭代器 (callable_iterator) 发送队列迭代器。

push_data(data: Any) None[source]
Overview:

将数据推入接收队列,如果接收队列已满(在推送后),则将接收队列中的所有数据推入发送队列。

Arguments:
  • 数据 (Any): 需要添加到接收队列中的数据

提示

线程安全

property remain_data_count: int
Overview:

返回接收队列的剩余数据计数

Returns:
  • count (int): 接收队列的大小。

run() None[source]
Overview:

启动缓存内部线程,例如超时监控线程。

LifoDeque

class ding.utils.data.structure.LifoDeque(maxsize=0)[源代码]
Overview:

类似于LifoQueue,但当队列满时自动替换最旧的数据。

Interfaces:

_init, _put, _get

_init(maxsize)[source]

data.base_dataloader

请参考 ding/utils/data/base_dataloader 了解更多详情。

IDataLoader

class ding.utils.data.base_dataloader.IDataLoader[来源]
Overview:

数据加载器的基类

Interfaces:

__init__, __next__, __iter__, _get_data, close

_get_data(batch_size: int | None = None) List[Tensor][source]
Overview:

获取一批数据

Arguments:
  • batch_size (Optional[int]): 有时,batch_size 由每次迭代指定,如果 batch_size 为 None,则使用默认的 batch_size 值

close() None[source]
Overview:

关闭数据加载器

data.collate_fn

请参考 ding/utils/data/collate_fn 了解更多详情。

ttorch_collate

ding.utils.data.collate_fn.ttorch_collate(x, json: bool = False, cat_1dim: bool = True)[source]
Overview:

将张量或张量的嵌套字典列表整理成单个张量或张量的嵌套字典。

Arguments:
  • x : 输入的张量列表或张量的嵌套字典。

  • json (bool): 如果为True,将输出转换为JSON格式。默认为False。

  • cat_1dim (bool): 如果为True,则沿最后一个维度连接形状为(B, 1)的张量。默认为True。

Returns:

整理后的输出张量或张量的嵌套字典。

Examples:
>>> # case 1: Collate a list of tensors
>>> tensors = [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6]), torch.tensor([7, 8, 9])]
>>> collated = ttorch_collate(tensors)
collated = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
>>> # case 2: Collate a nested dictionary of tensors
>>> nested_dict = {
        'a': torch.tensor([1, 2, 3]),
        'b': torch.tensor([4, 5, 6]),
        'c': torch.tensor([7, 8, 9])
    }
>>> collated = ttorch_collate(nested_dict)
collated = {
    'a': torch.tensor([1, 2, 3]),
    'b': torch.tensor([4, 5, 6]),
    'c': torch.tensor([7, 8, 9])
}
>>> # case 3: Collate a list of nested dictionaries of tensors
>>> nested_dicts = [
        {'a': torch.tensor([1, 2, 3]), 'b': torch.tensor([4, 5, 6])},
        {'a': torch.tensor([7, 8, 9]), 'b': torch.tensor([10, 11, 12])}
    ]
>>> collated = ttorch_collate(nested_dicts)
collated = {
    'a': torch.tensor([[1, 2, 3], [7, 8, 9]]),
    'b': torch.tensor([[4, 5, 6], [10, 11, 12]])
}

default_collate

ding.utils.data.collate_fn.default_collate(batch: Sequence, cat_1dim: bool = True, ignore_prefix: list = ['collate_ignore']) Tensor | Mapping | Sequence[source]
Overview:

将每个数据字段放入一个具有外部维度批量大小的张量中。

Arguments:
  • batch (Sequence): 一个数据序列,其长度为批量大小,其元素为一条数据。

  • cat_1dim (bool): 是否将形状为 (B, 1) 的张量连接为 (B),默认为 True。

  • ignore_prefix (list): 一个在整理字典时要忽略的前缀列表,默认值为 [‘collate_ignore’]。

Returns:
  • ret (Union[torch.Tensor, Mapping, Sequence]): 整理后的数据,每个数据字段都包含批量大小。返回的数据类型取决于原始元素的数据类型,可以是 [torch.Tensor, Mapping, Sequence]。

Example:
>>> # a list with B tensors shaped (m, n) -->> a tensor shaped (B, m, n)
>>> a = [torch.zeros(2,3) for _ in range(4)]
>>> default_collate(a).shape
torch.Size([4, 2, 3])
>>>
>>> # a list with B lists, each list contains m elements -->> a list of m tensors, each with shape (B, )
>>> a = [[0 for __ in range(3)] for _ in range(4)]
>>> default_collate(a)
[tensor([0, 0, 0, 0]), tensor([0, 0, 0, 0]), tensor([0, 0, 0, 0])]
>>>
>>> # a list with B dicts, whose values are tensors shaped :math:`(m, n)` -->>
>>> # a dict whose values are tensors with shape :math:`(B, m, n)`
>>> a = [{i: torch.zeros(i,i+1) for i in range(2, 4)} for _ in range(4)]
>>> print(a[0][2].shape, a[0][3].shape)
torch.Size([2, 3]) torch.Size([3, 4])
>>> b = default_collate(a)
>>> print(b[2].shape, b[3].shape)
torch.Size([4, 2, 3]) torch.Size([4, 3, 4])

timestep_collate

ding.utils.data.collate_fn.timestep_collate(batch: List[Dict[str, Any]]) Dict[str, Tensor | list][source]
Overview:

将一批时间步数据字段整理成张量,外部维度为批量大小。每个时间步数据字段表示为形状为[T, B, any_dims]的张量,其中T是序列的长度,B是批量大小,any_dims表示每个时间步张量的形状。

Arguments:
  • batch(List[Dict[str, Any]]): 一个长度为B的字典列表,其中每个字典代表一个时间步长的数据字段。每个字典包含一个键值对,其中键是数据字段的名称,值是具有任何形状的torch.Tensor对象序列。

Returns:
  • ret(Dict[str, Union[torch.Tensor, list]]): 整理后的数据,每个数据字段都包含了时间步长和批量大小。每个数据字段的形状为 [T, B, dim1, dim2, …]。

Examples:
>>> batch = [
        {'data0': [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6])]},
        {'data1': [torch.tensor([7, 8, 9]), torch.tensor([10, 11, 12])]}
    ]
>>> collated_data = timestep_collate(batch)
>>> print(collated_data['data'].shape)
torch.Size([2, 2, 3])

diff_shape_collate

ding.utils.data.collate_fn.diff_shape_collate(batch: Sequence) Tensor | Mapping | Sequence[source]
Overview:

整理一批具有不同形状的数据。 此函数类似于default_collate,但它允许批次中的张量具有None值,这在星际争霸观察中很常见。

Arguments:
  • batch (Sequence): 一个数据序列,其中每个元素都是一条数据。

Returns:
  • ret (Union[torch.Tensor, Mapping, Sequence]): 整理后的数据,每个数据字段都应用了批量大小。返回类型取决于原始元素类型,可以是 torch.Tensor、Mapping 或 Sequence。

Examples:
>>> # a list with B tensors shaped (m, n) -->> a tensor shaped (B, m, n)
>>> a = [torch.zeros(2,3) for _ in range(4)]
>>> diff_shape_collate(a).shape
torch.Size([4, 2, 3])
>>>
>>> # a list with B lists, each list contains m elements -->> a list of m tensors, each with shape (B, )
>>> a = [[0 for __ in range(3)] for _ in range(4)]
>>> diff_shape_collate(a)
[tensor([0, 0, 0, 0]), tensor([0, 0, 0, 0]), tensor([0, 0, 0, 0])]
>>>
>>> # a list with B dicts, whose values are tensors shaped :math:`(m, n)` -->>
>>> # a dict whose values are tensors with shape :math:`(B, m, n)`
>>> a = [{i: torch.zeros(i,i+1) for i in range(2, 4)} for _ in range(4)]
>>> print(a[0][2].shape, a[0][3].shape)
torch.Size([2, 3]) torch.Size([3, 4])
>>> b = diff_shape_collate(a)
>>> print(b[2].shape, b[3].shape)
torch.Size([4, 2, 3]) torch.Size([4, 3, 4])

default_decollate

ding.utils.data.collate_fn.default_decollate(batch: Tensor | Sequence | Mapping, ignore: List[str] = ['prev_state', 'prev_actor_state', 'prev_critic_state']) List[Any][源代码]
Overview:

拖出batch_size整理数据的批次大小以解整理它,这是default_collate的反向操作。

Arguments:
  • batch (Union[torch.Tensor, Sequence, Mapping]): 整理后的数据批次。它可以是一个张量、序列或映射。

  • 忽略(List[str]): 一个需要忽略的名称列表。仅当输入 batch 是字典时适用。如果键在此列表中,其值将保持不变,不会进行解构。默认值为 [‘prev_state’, ‘prev_actor_state’, ‘prev_critic_state’].

Returns:
  • ret (List[Any]): 一个包含B个元素的列表,其中B是批量大小。

Examples:
>>> batch = {
    'a': [
        [1, 2, 3],
        [4, 5, 6]
    ],
    'b': [
        [7, 8, 9],
        [10, 11, 12]
    ]}
>>> default_decollate(batch)
{
    0: {'a': [1, 2, 3], 'b': [7, 8, 9]},
    1: {'a': [4, 5, 6], 'b': [10, 11, 12]},
}

data.dataloader

请参考 ding/utils/data/dataloader 获取更多详细信息。

AsyncDataLoader

class ding.utils.data.dataloader.AsyncDataLoader(data_source: Callable | dict, batch_size: int, device: str, chunk_size: int | None = None, collate_fn: Callable | None = None, num_workers: int = 0)[源代码]
Overview:

一个异步数据加载器。

Interfaces:

__init__, __iter__, __next__, _get_data, _async_loop, _worker_loop, _cuda_loop, _get_data, close

__init__(data_source: Callable | dict, batch_size: int, device: str, chunk_size: int | None = None, collate_fn: Callable | None = None, num_workers: int = 0) None[源代码]
Overview:

使用输入参数初始化数据加载器。 如果data_sourcedict,数据将仅在get_data_thread中处理并放入async_train_queue。 如果data_sourceCallable,数据将通过实现函数进行处理,并且可以分为两种类型:

  • num_workers == 0 或 1: 只有主工作进程会处理它并将其放入 async_train_queue

  • num_workers > 1: 主工作器将任务分成若干部分,将每个任务推入 job_queue; 然后从工作器获取任务并执行; 最后它们将处理后的数据推入 async_train_queue.

在最后一步,如果device包含“cuda”,async_train_queue中的数据将被传输到cuda_queue以供用户访问。

Arguments:
  • data_source (Union[Callable, dict]): 数据源,例如要实现的函数(Callable),重放缓冲区的真实数据(dict)等。

  • batch_size (int): 批量大小。

  • 设备 (str): 设备。

  • chunk_size (int): 批次中分块的大小,应能精确整除 batch_size,仅在有多于1个工作者时起作用。

  • collate_fn (Callable): 用于将批量大小整理到每个数据字段中的函数。

  • num_workers (int): 额外工作者的数量。0 或 1 表示只有一个主要工作者,没有额外的,即多处理被禁用。大于 1 表示通过多处理实现的多个工作者将分别处理数据。

_async_loop(p: <module 'multiprocessing.connection' from '/home/docs/.asdf/installs/python/3.9.20/lib/python3.9/multiprocessing/connection.py'>, c: <module 'multiprocessing.connection' from '/home/docs/.asdf/installs/python/3.9.20/lib/python3.9/multiprocessing/connection.py'>) None[source]
Overview:

主工作进程。通过self.async_process运行。 首先,从self.get_data_thread获取数据。 如果有多个工作进程,将数据放入self.job_queue以进行进一步的多进程操作; 如果只有一个工作进程,处理数据并直接放入self.async_train_queue

Arguments:
  • p (tm.multiprocessing.connection): 父连接。

  • c (tm.multiprocessing.connection): 子连接。

_cuda_loop() None[source]
Overview:

只有在使用cuda时,才会通过self.cuda_thread作为线程运行。 从self.async_train_queue获取数据,更改其设备并将其放入self.cuda_queue

_get_data(p: <module 'multiprocessing.connection' from '/home/docs/.asdf/installs/python/3.9.20/lib/python3.9/multiprocessing/connection.py'>, c: <module 'multiprocessing.connection' from '/home/docs/.asdf/installs/python/3.9.20/lib/python3.9/multiprocessing/connection.py'>) None[source]
Overview:

使用输入参数初始化数据加载器。将通过self.get_data_thread作为线程运行。

Arguments:
  • p (tm.multiprocessing.connection): 父连接。

  • c (tm.multiprocessing.connection): 子连接。

_worker_loop() None[source]
Overview:

工作进程。遍历列表 self.worker 中的每个元素。 从 self.job_queue 获取数据任务,处理它,然后放入 self.async_train_queue。 仅在 self.num_workers > 1 时起作用,这意味着使用多进程。

close() None[source]
Overview:

删除此数据加载器。首先将end_flag设置为True,这意味着不同的进程/线程将清除并关闭所有数据队列;然后所有进程将被终止并加入。

data.dataset

请参考 ding/utils/data/dataset 获取更多详细信息。

数据集统计

class ding.utils.data.dataset.DatasetStatistics(mean: ndarray, std: ndarray, action_bounds: ndarray)[source]
Overview:

数据集统计。

__init__(mean: ndarray, std: ndarray, action_bounds: ndarray) None
action_bounds: ndarray
mean: ndarray
std: ndarray

NaiveRLDataset

class ding.utils.data.dataset.NaiveRLDataset(cfg)[source]
Overview:

Naive RL 数据集,用于离线 RL 算法。

Interfaces:

__init__, __len__, __getitem__

__getitem__(idx: int) Dict[str, Tensor][来源]
Overview:

获取数据集的项。

__init__(cfg) None[源代码]
Overview:

初始化方法。

Arguments:
  • cfg (dict): 配置字典。

__len__() int[来源]
Overview:

获取数据集的长度。

D4RL数据集

class ding.utils.data.dataset.D4RLDataset(cfg: dict)[source]
Overview:

D4RL数据集,用于离线强化学习算法。

Interfaces:

__init__, __len__, __getitem__

Properties:
  • 平均值 (np.ndarray): 数据集的平均值。

  • 标准差 (np.ndarray): 数据集的标准差。

  • action_bounds (np.ndarray): 数据集的动作边界。

  • 统计信息 (dict): 数据集的统计信息。

__getitem__(idx: int) Dict[str, Tensor][来源]
Overview:

获取数据集的项。

__init__(cfg: dict) None[源代码]
Overview:

初始化方法。

Arguments:
  • cfg (dict): 配置字典。

__len__() int[source]
Overview:

获取数据集的长度。

_cal_statistics(dataset, env, eps=0.001, add_action_buffer=True)[来源]
Overview:

计算数据集的统计信息。

Arguments:
  • 数据集 (Dict[str, np.ndarray]): d4rl 数据集。

  • 环境 (gym.Env): 环境。

  • eps (float): 误差值。

_load_d4rl(dataset: Dict[str, ndarray]) None[来源]
Overview:

加载d4rl数据集。

Arguments:
  • 数据集 (Dict[str, np.ndarray]): d4rl 数据集。

_normalize_states(dataset)[source]
Overview:

标准化状态。

Arguments:
  • 数据集 (Dict[str, np.ndarray]): d4rl 数据集。

property action_bounds: ndarray
Overview:

获取数据集的操作范围。

property data: List
property mean
Overview:

获取数据集的平均值。

property statistics: dict
Overview:

获取数据集的统计信息。

property std
Overview:

获取数据集的标准差。

HDF5数据集

class ding.utils.data.dataset.HDF5Dataset(cfg: dict)[source]
Overview:

HDF5数据集以hdf5格式保存,用于离线RL算法。 hdf5格式是Python中存储大型数值数组的常见格式。 有关更多详细信息,请参阅https://support.hdfgroup.org/HDF5/

Interfaces:

__init__, __len__, __getitem__

Properties:
  • 平均值 (np.ndarray): 数据集的平均值。

  • 标准差 (np.ndarray): 数据集的标准差。

  • action_bounds (np.ndarray): 数据集的动作边界。

  • 统计信息 (dict): 数据集的统计信息。

__getitem__(idx: int) Dict[str, Tensor][source]
Overview:

获取数据集的项。

Arguments:
  • idx (int): 数据集的索引。

__init__(cfg: dict) None[源代码]
Overview:

初始化方法。

Arguments:
  • cfg (dict): 配置字典。

__len__() int[source]
Overview:

获取数据集的长度。

_cal_statistics(eps: float = 0.001)[source]
Overview:

计算数据集的统计信息。

Arguments:
  • eps (float): 误差值。

_load_data(dataset: Dict[str, ndarray]) None[source]
Overview:

加载数据集。

Arguments:
  • 数据集 (Dict[str, np.ndarray]): 数据集。

_normalize_states()[source]
Overview:

标准化状态。

property action_bounds: ndarray
Overview:

获取数据集的操作范围。

property mean
Overview:

获取数据集的平均值。

property statistics: dict
Overview:

获取数据集的统计信息。

property std
Overview:

获取数据集的标准差。

D4RL轨迹数据集

class ding.utils.data.dataset.D4RLTrajectoryDataset(cfg: dict)[来源]
Overview:

D4RL轨迹数据集,用于离线RL算法。

Interfaces:

__init__, __len__, __getitem__

D4RL_DATASET_STATS = {'halfcheetah-medium-expert-v2': {'state_mean': [-0.05667462572455406, 0.024369969964027405, -0.061670560389757156, -0.22351515293121338, -0.2675151228904724, -0.07545716315507889, -0.05809682980179787, -0.027675075456500053, 8.110626220703125, -0.06136331334710121, -0.17986927926540375, 0.25175222754478455, 0.24186332523822784, 0.2519369423389435, 0.5879552960395813, -0.24090635776519775, -0.030184272676706314], 'state_std': [0.06103534251451492, 0.36054104566574097, 0.45544400811195374, 0.38476887345314026, 0.2218363732099533, 0.5667523741722107, 0.3196682929992676, 0.2852923572063446, 3.443821907043457, 0.6728139519691467, 1.8616976737976074, 9.575807571411133, 10.029894828796387, 5.903450012207031, 12.128185272216797, 6.4811787605285645, 6.378620147705078]}, 'halfcheetah-medium-replay-v2': {'state_mean': [-0.12880703806877136, 0.3738119602203369, -0.14995987713336945, -0.23479078710079193, -0.2841278612613678, -0.13096535205841064, -0.20157982409000397, -0.06517726927995682, 3.4768247604370117, -0.02785065770149231, -0.015035249292850494, 0.07697279006242752, 0.01266712136566639, 0.027325302362442017, 0.02316424623131752, 0.010438721626996994, -0.015839405357837677], 'state_std': [0.17019015550613403, 1.284424901008606, 0.33442774415016174, 0.3672759234905243, 0.26092398166656494, 0.4784106910228729, 0.3181420564651489, 0.33552637696266174, 2.0931615829467773, 0.8037433624267578, 1.9044333696365356, 6.573209762573242, 7.572863578796387, 5.069749355316162, 9.10555362701416, 6.085654258728027, 7.25300407409668]}, 'halfcheetah-medium-v2': {'state_mean': [-0.06845773756504059, 0.016414547339081764, -0.18354906141757965, -0.2762460708618164, -0.34061527252197266, -0.09339715540409088, -0.21321271359920502, -0.0877423882484436, 5.173007488250732, -0.04275195300579071, -0.036108363419771194, 0.14053793251514435, 0.060498327016830444, 0.09550975263118744, 0.06739100068807602, 0.005627387668937445, 0.013382787816226482], 'state_std': [0.07472999393939972, 0.3023499846458435, 0.30207309126853943, 0.34417077898979187, 0.17619241774082184, 0.507205605506897, 0.2567007839679718, 0.3294812738895416, 1.2574149370193481, 0.7600541710853577, 1.9800915718078613, 6.565362453460693, 7.466367721557617, 4.472222805023193, 10.566964149475098, 5.671932697296143, 7.4982590675354]}, 'hopper-medium-expert-v2': {'state_mean': [1.3293815851211548, -0.09836531430482864, -0.5444297790527344, -0.10201650857925415, 0.02277466468513012, 2.3577215671539307, -0.06349576264619827, -0.00374026270583272, -0.1766270101070404, -0.11862941086292267, -0.12097819894552231], 'state_std': [0.17012375593185425, 0.05159067362546921, 0.18141433596611023, 0.16430604457855225, 0.6023368239402771, 0.7737284898757935, 1.4986555576324463, 0.7483318448066711, 1.7953159809112549, 2.0530025959014893, 5.725032806396484]}, 'hopper-medium-replay-v2': {'state_mean': [1.2305138111114502, -0.04371410980820656, -0.44542956352233887, -0.09370097517967224, 0.09094487875699997, 1.3694725036621094, -0.19992674887180328, -0.022861352190375328, -0.5287045240402222, -0.14465883374214172, -0.19652697443962097], 'state_std': [0.1756512075662613, 0.0636928603053093, 0.3438323438167572, 0.19566889107227325, 0.5547984838485718, 1.051029920578003, 1.158307671546936, 0.7963128685951233, 1.4802359342575073, 1.6540331840515137, 5.108601093292236]}, 'hopper-medium-v2': {'state_mean': [1.311279058456421, -0.08469521254301071, -0.5382719039916992, -0.07201576232910156, 0.04932365566492081, 2.1066856384277344, -0.15017354488372803, 0.008783451281487942, -0.2848185896873474, -0.18540096282958984, -0.28461286425590515], 'state_std': [0.17790751159191132, 0.05444620922207832, 0.21297138929367065, 0.14530418813228607, 0.6124444007873535, 0.8517446517944336, 1.4515252113342285, 0.6751695871353149, 1.5362390279769897, 1.616074562072754, 5.607253551483154]}, 'walker2d-medium-expert-v2': {'state_mean': [1.2294334173202515, 0.16869689524173737, -0.07089081406593323, -0.16197483241558075, 0.37101927399635315, -0.012209027074277401, -0.42461398243904114, 0.18986578285694122, 3.162475109100342, -0.018092676997184753, 0.03496946766972542, -0.013921679928898811, -0.05937029421329498, -0.19549426436424255, -0.0019200450042262673, -0.062483321875333786, -0.27366524934768677], 'state_std': [0.09932824969291687, 0.25981399416923523, 0.15062759816646576, 0.24249176681041718, 0.6758718490600586, 0.1650741547346115, 0.38140663504600525, 0.6962361335754395, 1.3501490354537964, 0.7641991376876831, 1.534574270248413, 2.1785972118377686, 3.276582717895508, 4.766193866729736, 1.1716983318328857, 4.039782524108887, 5.891613960266113]}, 'walker2d-medium-replay-v2': {'state_mean': [1.209364652633667, 0.13264022767543793, -0.14371201395988464, -0.2046516090631485, 0.5577612519264221, -0.03231537342071533, -0.2784661054611206, 0.19130706787109375, 1.4701707363128662, -0.12504704296588898, 0.0564953051507473, -0.09991033375263214, -0.340340256690979, 0.03546293452382088, -0.08934258669614792, -0.2992438077926636, -0.5984178185462952], 'state_std': [0.11929835379123688, 0.3562574088573456, 0.25852200388908386, 0.42075422406196594, 0.5202291011810303, 0.15685082972049713, 0.36770978569984436, 0.7161387801170349, 1.3763766288757324, 0.8632221817970276, 2.6364643573760986, 3.0134117603302, 3.720684051513672, 4.867283821105957, 2.6681625843048096, 3.845186948776245, 5.4768385887146]}, 'walker2d-medium-v2': {'state_mean': [1.218966007232666, 0.14163373410701752, -0.03704913705587387, -0.13814310729503632, 0.5138224363327026, -0.04719110205769539, -0.47288352251052856, 0.042254164814949036, 2.3948874473571777, -0.03143199160695076, 0.04466355964541435, -0.023907244205474854, -0.1013401448726654, 0.09090937674045563, -0.004192637279629707, -0.12120571732521057, -0.5497063994407654], 'state_std': [0.12311358004808426, 0.3241879940032959, 0.11456084251403809, 0.2623065710067749, 0.5640279054641724, 0.2271878570318222, 0.3837319612503052, 0.7373676896095276, 1.2387926578521729, 0.798020601272583, 1.5664079189300537, 1.8092705011367798, 3.025604248046875, 4.062486171722412, 1.4586567878723145, 3.7445690631866455, 5.5851287841796875]}}
REF_MAX_SCORE = {'halfcheetah': 12135.0, 'hopper': 3234.3, 'walker2d': 4592.3}
REF_MIN_SCORE = {'halfcheetah': -280.178953, 'hopper': -20.272305, 'walker2d': 1.629008}
__getitem__(idx: int) Tuple[Tensor, Tensor, Tensor, Tensor, Tensor][source]
Overview:

获取数据集的项。

Arguments:
  • idx (int): 数据集的索引。

__init__(cfg: dict) None[源代码]
Overview:

初始化方法。

Arguments:
  • cfg (dict): 配置字典。

__len__() int[源代码]
Overview:

获取数据集的长度。

get_d4rl_dataset_stats(env_d4rl_name: str) Dict[str, list][source]
Overview:

获取d4rl数据集统计信息。

Arguments:
  • env_d4rl_name (str): d4rl环境的名称。

get_max_timestep() int[来源]
Overview:

获取数据集的最大时间步长。

get_state_stats() Tuple[ndarray, ndarray][来源]
Overview:

获取数据集的状态均值和标准差。

D4RLDiffuserDataset

class ding.utils.data.dataset.D4RLDiffuserDataset(dataset_path: str, context_len: int, rtg_scale: float)[来源]
Overview:

D4RL扩散器数据集,用于离线RL算法。

Interfaces:

__init__, __len__, __getitem__

__init__(dataset_path: str, context_len: int, rtg_scale: float) None[source]
Overview:

D4RLDiffuserDataset的初始化方法。

Arguments:
  • dataset_path (str): 数据集路径。

  • context_len (int): 上下文的长度。

  • rtg_scale (float): 返回的规模。

固定回放缓冲区

class ding.utils.data.dataset.FixedReplayBuffer(data_dir: str, replay_suffix: int, *args, **kwargs)[source]
Overview:

由一系列OutofGraphReplayBuffers组成的对象。

Interfaces:

__init__, get_transition_elements, sample_transition_batch

__init__(data_dir: str, replay_suffix: int, *args, **kwargs)[source]
Overview:

初始化 FixedReplayBuffer 类。

Arguments:
  • data_dir (str): 从中加载回放缓冲区的日志目录。

  • replay_suffix (int): 如果不为None,则仅加载数据目录中与特定后缀对应的回放缓冲区。

  • args (list): 任意额外的参数。

  • kwargs (dict): 任意关键字参数。

_load_buffer(suffix)[source]
Overview:

加载一个OutOfGraphReplayBuffer回放缓冲区。

Arguments:
  • 后缀 (int): 回放缓冲区的后缀。

get_transition_elements()[来源]
Overview:

返回过渡元素。

load_single_buffer(suffix)[source]
Overview:

加载单个回放缓冲区。

Arguments:
  • 后缀 (int): 回放缓冲区的后缀。

sample_transition_batch(batch_size=None, indices=None)[来源]
Overview:

返回一批转换(包括任何额外内容)。

Arguments:
  • batch_size (int): 批量大小。

  • 索引 (list): 批次的索引。

PCDataset

class ding.utils.data.dataset.PCDataset(all_data)[source]
Overview:

用于程序克隆的数据集。

Interfaces:

__init__, __len__, __getitem__

__getitem__(item)[source]
Overview:

获取数据集的项。

Arguments:
  • 项目 (int): 数据集的索引。

__init__(all_data)[来源]
Overview:

PCDataset的初始化方法。

Arguments:
  • all_data (tuple): 所有数据的元组。

__len__()[source]
Overview:

获取数据集的长度。

加载BFS数据集

ding.utils.data.dataset.load_bfs_datasets(train_seeds=1, test_seeds=5)[来源]
Overview:

加载BFS数据集。

Arguments:
  • train_seeds (int): 训练种子的数量。

  • test_seeds (int): 测试种子的数量。

BCO数据集

class ding.utils.data.dataset.BCODataset(data=None)[source]
Overview:

用于从观察中进行行为克隆的数据集。

Interfaces:

__init__, __len__, __getitem__

Properties:
  • obs (np.ndarray): 观测数组。

  • 动作 (np.ndarray): 动作数组。

__getitem__(idx)[来源]
Overview:

获取数据集的项。

Arguments:
  • idx (int): 数据集的索引。

__init__(data=None)[源代码]
Overview:

BCODataset的初始化方法。

Arguments:
  • 数据 (dict): 数据字典。

__len__()[来源]
Overview:

获取数据集的长度。

property action
Overview:

获取操作数组。

property obs
Overview:

获取观测数组。

序列数据集

class ding.utils.data.dataset.SequenceDataset(cfg)[source]
Overview:

扩散器的数据集。

Interfaces:

__init__, __len__, __getitem__

__getitem__(idx, eps=0.0001)[来源]
Overview:

获取数据集的项。

Arguments:
  • idx (int): 数据集的索引。

  • eps (float): 该epsilon值。

__init__(cfg)[源代码]
Overview:

SequenceDataset的初始化方法。

Arguments:
  • cfg (dict): 配置字典。

__len__()[来源]
Overview:

获取数据集的长度。

_get_bounds()[source]
Overview:

获取数据集的边界。

get_conditions(observations)[source]
Overview:

获取当前观察的条件以进行规划。

Arguments:
  • 观测值 (np.ndarray): 观测数组。

make_indices(path_lengths, horizon)[source]
Overview:

为从数据集中采样创建索引。每个索引映射到一个数据点。

Arguments:
  • 路径长度数组 (np.ndarray): 路径长度数组。

  • horizon (int): 时间范围。

maze2d_set_terminals(env, dataset)[source]
Overview:

设置maze2d的终端。

Arguments:
  • 环境 (gym.Env): 健身房环境。

  • 数据集 (dict): 数据集字典。

normalize(keys=['observations', 'actions'])[来源]
Overview:

标准化数据集,标准化将由扩散模型预测的字段

Arguments:
  • keys (list): 键的列表。

normalize_value(value)[source]
Overview:

标准化该值。

Arguments:
  • 值 (np.ndarray): 值数组。

process_maze2d_episode(episode)[source]
Overview:

处理迷宫2D片段,将next_observations字段添加到片段中。

Arguments:
  • 剧集 (dict): 剧集字典。

sequence_dataset(env, dataset=None)[source]
Overview:

对数据集进行排序。

Arguments:
  • 环境 (gym.Env): 健身房环境。

hdf5保存

ding.utils.data.dataset.hdf5_save(exp_data, expert_data_path)[source]
Overview:

将数据保存到hdf5。

naive_save

ding.utils.data.dataset.naive_save(exp_data, expert_data_path)[source]
Overview:

将数据保存到pickle。

离线数据保存类型

ding.utils.data.dataset.offline_data_save_type(exp_data, expert_data_path, data_type='naive')[source]
Overview:

保存离线数据。

创建数据集

ding.utils.data.dataset.create_dataset(cfg, **kwargs) Dataset[source]
Overview:

创建数据集。

bfs_helper

请参考 ding/utils/bfs_helper 获取更多详细信息。

获取VI序列

ding.utils.bfs_helper.get_vi_sequence(env: Env, observation: ndarray) Tuple[ndarray, List][源代码]
Overview:

给定迷宫环境的一个实例和当前的观察结果,使用广度优先搜索(BFS)算法规划一条最优路径并记录结果。

Arguments:
  • 环境 (Env): 迷宫环境的实例。

  • 观察值 (np.ndarray): 当前的观察值。

Returns:
  • 输出 (Tuple[np.ndarray, List]): BFS 的结果。 output[0] 包含每次迭代后的 BFS 地图,output[1] 包含到达终点前的最优动作。

collection_helper

请参考 ding/utils/collection_helper 获取更多详细信息。

iter_mapping

ding.utils.collection_helper.iter_mapping(iter_: Iterable[_IterType], mapping: Callable[[_IterType], _IterTargetType])[source]
Overview:

将可迭代元素的列表映射到输入迭代可调用

Arguments:
  • iter_(_IterType list): 用于迭代的列表

  • 映射 (Callable [[_IterType], _IterTargetType]): 一个可调用的函数,用于映射可迭代元素。

Return:
  • (iter_mapping object): 迭代结果

Example:
>>> iterable_list = [1, 2, 3, 4, 5]
>>> _iter = iter_mapping(iterable_list, lambda x: x ** 2)
>>> print(list(_iter))
[1, 4, 9, 16, 25]

压缩助手

请参考 ding/utils/compression_helper 获取更多详细信息。

CloudPickleWrapper

class ding.utils.compression_helper.CloudPickleWrapper(data: Any)[source]
Overview:

CloudPickleWrapper 能够序列化更多的 Python 对象(例如:带有 lambda 表达式的对象)。

Interfaces:

__init__, __getstate__, __setstate__.

__init__(data: Any) None[source]
Overview:

使用给定的参数初始化 CloudPickleWrapper。

Arguments:
  • data (Any): 要转储的对象。

虚拟压缩器

ding.utils.compression_helper.dummy_compressor(data: Any) Any[source]
Overview:

返回原始输入数据。

Arguments:
  • data (Any): 压缩器的输入数据。

Returns:
  • 输出 (Any): 此压缩器将准确返回输入数据。

zlib数据压缩器

ding.utils.compression_helper.zlib_data_compressor(data: Any) bytes[source]
Overview:

获取输入的压缩数据并返回压缩的原始数据(zlib 压缩器)以二进制格式。

Arguments:
  • data (Any): 压缩器的输入数据。

Returns:
  • 输出 (bytes): 压缩后的字节类结果。

Examples:
>>> zlib_data_compressor("Hello")

lz4数据压缩器

ding.utils.compression_helper.lz4_data_compressor(data: Any) bytes[source]
Overview:

返回压缩的原始数据(lz4 压缩器)。压缩器以二进制格式输出。

Arguments:
  • data (Any): 压缩器的输入数据。

Returns:
  • 输出 (bytes): 压缩后的字节类结果。

Examples:
>>> lz4.block.compress(pickle.dumps("Hello"))
b'R€•      ŒHello”.'

jpeg_data_compressor

ding.utils.compression_helper.jpeg_data_compressor(data: ndarray) bytes[source]
Overview:

为了减少内存使用,我们可以选择在缓冲区中存储图像的jpeg字符串而不是numpy数组。此函数将观察的numpy数组编码为jpeg字符串。

Arguments:
  • 数据 (np.array): 观察的numpy数组。

Returns:
  • img_str (bytes): 压缩后的字节类结果。

获取数据压缩器

ding.utils.compression_helper.get_data_compressor(name: str)[source]
Overview:

根据输入名称获取数据压缩器。

Arguments:
  • name(str): 压缩器的名称,支持 ['lz4', 'zlib', 'jpeg', 'none']

Return:
  • compressor (Callable): 对应的数据压缩器,接收输入数据并返回压缩后的数据。

Example:
>>> compress_fn = get_data_compressor('lz4')
>>> compressed_data = compressed(input_data)

dummy_decompressor

ding.utils.compression_helper.dummy_decompressor(data: Any) Any[source]
Overview:

返回输入数据。

Arguments:
  • data (Any): 解压缩器的输入数据。

Returns:
  • 输出 (bytes): 解压缩的结果,与输入完全相同。

lz4数据解压缩器

ding.utils.compression_helper.lz4_data_decompressor(compressed_data: bytes) Any[source]
Overview:

返回解压缩的原始数据(lz4 压缩器)。

Arguments:
  • 数据 (bytes): 解压缩器的输入数据。

Returns:
  • 输出 (Any): 解压缩后的对象。

zlib数据解压缩器

ding.utils.compression_helper.zlib_data_decompressor(compressed_data: bytes) Any[source]
Overview:

返回解压缩的原始数据(zlib 压缩器)。

Arguments:
  • 数据 (bytes): 解压缩器的输入数据。

Returns:
  • 输出 (Any): 解压缩后的对象。

jpeg_data_decompressor

ding.utils.compression_helper.jpeg_data_decompressor(compressed_data: bytes, gray_scale=False) ndarray[source]
Overview:

为了减少内存使用,我们可以选择在缓冲区中存储图像的jpeg字符串而不是numpy数组。此函数从jpeg字符串解码观察numpy数组。

Arguments:
  • 压缩数据 (bytes): jpeg字符串。

  • gray_scale (bool): If the observation is gray, gray_scale=True,

    如果观察是RGB,gray_scale=False

Returns:
  • arr (np.ndarray): 解压后的numpy数组。

获取数据解压缩器

ding.utils.compression_helper.get_data_decompressor(name: str) Callable[source]
Overview:

根据输入名称获取数据解压缩器。

Arguments:
  • name(str): 解压缩器的名称,支持 ['lz4', 'zlib', 'none']

注意

对于所有的解压缩器,都需要输入一个类似字节的对象。

Returns:
  • 解压缩器 (Callable): 对应的数据解压缩器。

Examples:
>>> decompress_fn = get_data_decompressor('lz4')
>>> origin_data = compressed(compressed_data)

default_helper

请参考 ding/utils/default_helper 获取更多详细信息。

获取形状0

ding.utils.default_helper.get_shape0(data: List | Dict | Tensor | Tensor) int[source]
Overview:

获取数据的torch张量或treetensor的形状[0]

Arguments:
  • 数据 (Union[List,Dict,torch.Tensor,ttorch.Tensor]): 需要分析的数据

Returns:
  • shape[0] (int): 数据的第一个维度长度,通常是批量大小。

列表转字典

ding.utils.default_helper.lists_to_dicts(data: List[dict | NamedTuple] | Tuple[dict | NamedTuple], recursive: bool = False) Mapping[object, object] | NamedTuple[source]
Overview:

将字典列表转换为列表字典。

Arguments:
  • data (Union[List[Union[dict, NamedTuple]], Tuple[Union[dict, NamedTuple]]]):

    需要转换一个列表的字典

  • recursive (bool): 是否递归处理字典元素

Returns:
  • newdata (Union[Mapping[object, object], NamedTuple]): 作为结果的字典列表

Example:
>>> from ding.utils import *
>>> lists_to_dicts([{1: 1, 10: 3}, {1: 2, 10: 4}])
{1: [1, 2], 10: [3, 4]}

字典转列表

ding.utils.default_helper.dicts_to_lists(data: Mapping[object, List[object]]) List[Mapping[object, object]][source]
Overview:

将一个列表的字典转换为字典的列表。

Arguments:
  • data (Mapping[object, list]): 需要转换的字典列表

Returns:
  • newdata (List[Mapping[object, object]]): 作为结果的字典列表

Example:
>>> from ding.utils import *
>>> dicts_to_lists({1: [1, 2], 10: [3, 4]})
[{1: 1, 10: 3}, {1: 2, 10: 4}]

覆盖

ding.utils.default_helper.override(cls: type) Callable[[Callable], Callable][source]
Overview:

用于记录方法覆盖的注解。

Arguments:
  • cls (type): The superclass that provides the overridden method. If this

    cls 实际上没有该方法,会引发错误。

压缩

ding.utils.default_helper.squeeze(data: object) object[source]
Overview:

从元组、列表或字典中提取数据到单个对象

Arguments:
  • 数据 (object): 要被压缩的数据

Example:
>>> a = (4, )
>>> a = squeeze(a)
>>> print(a)
>>> 4

default_get

ding.utils.default_helper.default_get(data: dict, name: str, default_value: Any | None = None, default_fn: Callable | None = None, judge_fn: Callable | None = None) Any[source]
Overview:

通过输入获取值,通常检查至少包含dataname的输入。如果name存在于data中,则获取name处的值;否则,将name添加到default_get_set中,其值由default_fn生成(或直接作为default_value),并通过``judge_fn``检查其合法性。

Arguments:
  • data(dict): 数据输入字典

  • 名称(str): 键名

  • 默认值(Optional[Any]) = None,

  • default_fn(Optional[Callable]) = 值

  • judge_fn(Optional[Callable]) = None

Returns:
  • ret(list): 分割后的数据

  • 残差(list): 残差列表

列表分割

ding.utils.default_helper.list_split(data: list, step: int) List[list][source]
Overview:

按步骤分割数据列表。

Arguments:
  • data(list): 用于分割的数据列表

  • 步骤(int): 分割的步骤数

Returns:
  • ret(list): 分割后的数据列表。

  • 残差(list): 残差列表。当data除以steps时,此值为None

Example:
>>> list_split([1,2,3,4],2)
([[1, 2], [3, 4]], None)
>>> list_split([1,2,3,4],3)
([[1, 2, 3]], [4])

error_wrapper

ding.utils.default_helper.error_wrapper(fn, default_ret, warning_msg='')[source]
Overview:

包装函数,以便捕获函数中的任何异常并返回 default_ret

Arguments:
  • fn (Callable): 要包装的函数

  • default_ret (obj): 函数中发生异常时的默认返回值

Returns:
  • wrapper (Callable): 被包装的函数

Examples:
>>> # Used to checkfor Fakelink (Refer to utils.linklink_dist_helper.py)
>>> def get_rank():  # Get the rank of linklink model, return 0 if use FakeLink.
>>>    if is_fake_link:
>>>        return 0
>>>    return error_wrapper(link.get_rank, 0)()

有限空间容器

class ding.utils.default_helper.LimitedSpaceContainer(min_val: int, max_val: int)[source]
Overview:

一个太空模拟器。

Interfaces:

__init__, get_residual_space, release_space

__init__(min_val: int, max_val: int) None[source]
Overview:

设置容器的min_valmax_val,并将cur设置为min_val以进行初始化。

Arguments:
  • min_val (int): 容器的最小容量,通常为0。

  • max_val (int): 容器的最大容量。

acquire_space() bool[source]
Overview:

尝试获取一块空间。如果有,返回True;否则返回False。

Returns:
  • flag (bool): 是否存在任何残留空间。

decrease_space() None[source]
Overview:

减少空间中的一块。递减 max_val

get_residual_space() int[源代码]
Overview:

获取所有剩余的空间片段。将 cur 设置为 max_val

Arguments:
  • ret (int): 剩余空间,由 max_val - cur 计算得出。

increase_space() None[source]
Overview:

在空间中增加一块。增加 max_val

release_space() None[source]
Overview:

仅释放一块空间。递减cur,但确保它不会变为负数。

深度合并字典

ding.utils.default_helper.deep_merge_dicts(original: dict, new_dict: dict) dict[source]
Overview:

通过调用deep_update合并两个字典

Arguments:
  • 原始 (dict): 字典 1.

  • new_dict (dict): 字典 2.

Returns:
  • merged_dict (dict): 一个新的字典,它是d1和d2深度合并的结果。

deep_update

ding.utils.default_helper.deep_update(original: dict, new_dict: dict, new_keys_allowed: bool = False, whitelist: List[str] | None = None, override_all_if_type_changes: List[str] | None = None)[source]
Overview:

使用来自new_dict的值递归更新原始字典。

Arguments:
  • 原始 (dict): 带有默认值的字典。

  • new_dict (dict): 包含要更新的值的字典

  • new_keys_allowed (bool): 是否允许新密钥。

  • whitelist (Optional[List[str]]):

    对应于字典值的键列表,可以在其中引入新的子键。这仅在顶层有效。

  • override_all_if_type_changes(Optional[List[str]]):

    顶级键列表,其值为字典,如果该值字典中的“type”键发生变化,我们总是简单地覆盖整个值(dict)。

注意

如果在新字典中引入了新键,那么如果 new_keys_allowed 不为 True,则会抛出错误。此外,对于子字典,如果键在白名单中,则可以引入新的子键。

扁平化字典

ding.utils.default_helper.flatten_dict(data: dict, delimiter: str = '/') dict[source]
Overview:

将字典展平,参见示例

Arguments:
  • 数据 (dict): 原始嵌套字典

  • delimiter (str): 新字典键的分隔符

Returns:
  • 数据 (dict): 展平的嵌套字典

Example:
>>> a
{'a': {'b': 100}}
>>> flatten_dict(a)
{'a/b': 100}

设置包种子

ding.utils.default_helper.set_pkg_seed(seed: int, use_cuda: bool = True) None[source]
Overview:

副作用函数用于为randomnumpy randomtorch's manual seed设置种子。这通常用于入口脚本中,为所有包和实例设置随机种子的部分。

Argument:
  • seed(int): 设置种子

  • use_cuda(bool) 是否使用cuda

Examples:
>>> # ../entry/xxxenv_xxxpolicy_main.py
>>> ...
# Set random seed for all package and instance
>>> collector_env.seed(seed)
>>> evaluator_env.seed(seed, dynamic_seed=False)
>>> set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
>>> ...
# Set up RL Policy, etc.
>>> ...

一次性警告

ding.utils.default_helper.one_time_warning(warning_msg: str) None[source]
Overview:

仅打印一次警告信息。

Arguments:
  • warning_msg (str): 警告信息。

split_fn

ding.utils.default_helper.split_fn(data, indices, start, end)[源代码]
Overview:

按索引分割数据

Arguments:
  • 数据 (Union[List, Dict, torch.Tensor, ttorch.Tensor]): 要分析的数据

  • 索引 (np.ndarray): 要分割的索引

  • 开始 (int): 起始索引

  • 结束 (int): 结束索引

split_data_generator

ding.utils.default_helper.split_data_generator(data: dict, split_size: int, shuffle: bool = True) dict[source]
Overview:

将数据分成批次

Arguments:
  • 数据 (dict): 需要分析的数据

  • split_size (int): 分割大小

  • shuffle (bool): 是否打乱

RunningMeanStd

class ding.utils.default_helper.RunningMeanStd(epsilon=0.0001, shape=(), device=device(type='cpu'))[source]
Overview:

用于更新新变量、新均值和新计数的包装器

Interfaces:

__init__, update, reset, new_shape

Properties:
  • mean, std, _epsilon, _shape, _mean, _var, _count

__init__(epsilon=0.0001, shape=(), device=device(type='cpu'))[源代码]
Overview:

初始化 self. 请参阅 help(type(self)) 以获取准确的签名;设置属性。

Arguments:
  • env (gym.Env): 要包装的环境。

  • epsilon (Float): 用于自标准输出的epsilon

  • 形状 (:obj: np.array): 用于表示此包装器在均值和方差属性上的np数组形状

property mean: ndarray
Overview:

属性 meanself._mean 获取

static new_shape(obs_shape, act_shape, rew_shape)[source]
Overview:

获取观察、动作和奖励的新形状;在这种情况下保持不变。

Arguments:

obs_shape (Any), act_shape (Any), rew_shape (Any)

Returns:

观测形状 (Any), 动作形状 (Any), 奖励形状 (Any)

reset()[source]
Overview:

重置环境的状态和重置属性:_mean, _var, _count

property std: ndarray
Overview:

属性 stdself._varself._epsilon 的 epsilon 值计算得出

update(x)[source]
Overview:

更新均值、变量和计数

Arguments:
  • x: 批次

make_key_as_identifier

ding.utils.default_helper.make_key_as_identifier(data: Dict[str, Any]) Dict[str, Any][source]
Overview:

将字典的键转换为合法的Python标识符字符串,以便它与一些Python魔术方法兼容,例如__getattr

Arguments:
  • data (Dict[str, Any]): 原始字典数据。

Return:
  • new_data (Dict[str, Any]): 具有合法标识符键的新字典数据。

remove_illegal_item

ding.utils.default_helper.remove_illegal_item(data: Dict[str, Any]) Dict[str, Any][source]
Overview:

移除字典信息中的非法项,如字符串,这与Tensor不兼容。

Arguments:
  • 数据 (Dict[str, Any]): 原始的字典数据。

Return:
  • new_data (Dict[str, Any]): 不包含合法项的新字典数据。

设计助手

请参考 ding/utils/design_helper 获取更多详细信息。

单例元类

class ding.utils.design_helper.SingletonMetaclass(name, bases, namespace, **kwargs)[source]
Overview:

返回输入类中的给定类型实例

Interfaces:

__call__

instances = {<class 'ding.framework.parallel.Parallel'>: <ding.framework.parallel.Parallel object>}

fast_copy

请参考 ding/utils/fast_copy 获取更多详细信息。

_快速复制

class ding.utils.fast_copy._FastCopy[source]
Overview:

这个类的想法来自于这篇文章 https://newbedev.com/what-is-a-fast-pythonic-way-to-deepcopy-just-data-from-a-python-dict-or-list。 我们使用递归调用来复制每个需要复制的对象,这将比 copy.deepcopy 快5倍。

Interfaces:

__init__, _copy_list, _copy_dict, _copy_tensor, _copy_ndarray, copy.

__init__()[源代码]
Overview:

初始化 _FastCopy 对象。

_copy_dict(d: dict) dict[source]
Overview:

复制字典。

Arguments:
  • d (dict): 要复制的字典。

_copy_list(l: List) dict[source]
Overview:

复制列表。

Arguments:
  • l (List): 要复制的列表。

_copy_ndarray(a: ndarray) ndarray[source]
Overview:

复制ndarray。

Arguments:
  • a (np.ndarray): 要复制的ndarray。

_copy_tensor(t: Tensor) Tensor[source]
Overview:

复制张量。

Arguments:
  • t (torch.Tensor): 要复制的张量。

copy(sth: Any) Any[source]
Overview:

复制对象。

Arguments:
  • sth (Any): 要复制的对象。

文件助手

请参考 ding/utils/file_helper 获取更多详细信息。

从ceph读取

ding.utils.file_helper.read_from_ceph(path: str) object[source]
Overview:

从ceph读取文件

Arguments:
  • 路径 (str): 在ceph中的文件路径,以 "s3://" 开头

Returns:
  • (data): 反序列化的数据

_get_redis

ding.utils.file_helper._get_redis(host='localhost', port=6379)[source]
Overview:

确保Redis的使用

Arguments:
  • 主机 (str): 主机字符串

  • 端口 (int): 端口号

Returns:
  • (Redis(object)): 具有给定 host, port, 和 db=0 的 Redis 对象

从Redis读取

ding.utils.file_helper.read_from_redis(path: str) object[source]
Overview:

从redis读取文件

Arguments:
  • 路径 (str): Redis中的文件路径,可以是一个字符串键

Returns:
  • (data): 反序列化的数据

_ensure_rediscluster

ding.utils.file_helper._ensure_rediscluster(startup_nodes=[{'host': '127.0.0.1', 'port': '7000'}])[source]
Overview:

确保Redis的使用

Arguments:
  • List of startup nodes (dict) of
    • 主机 (str): 主机字符串

    • 端口号 (int): 端口号

Returns:
  • (RedisCluster(object)): 具有给定 host, port, 并且默认情况下 decode_responsesFalse 的 RedisCluster 对象。

从Redis集群读取

ding.utils.file_helper.read_from_rediscluster(path: str) object[source]
Overview:

从rediscluster读取文件

Arguments:
  • 路径 (str): Redis集群中的文件路径,可以是一个字符串键

Returns:
  • (data): 反序列化的数据

从文件中读取

ding.utils.file_helper.read_from_file(path: str) object[source]
Overview:

从本地文件系统读取文件

Arguments:
  • 路径 (str): 本地文件系统中的文件路径

Returns:
  • (data): 反序列化的数据

_ensure_memcached

ding.utils.file_helper._ensure_memcached()[source]
Overview:

确保使用memcache

Returns:
  • (MemcachedClient instance): 使用当前的 memcached_client 的 server_list.confclient.conf 文件构建的 MemcachedClient 类实例

从mc读取

ding.utils.file_helper.read_from_mc(path: str, flush=False) object[来源]
Overview:

从memcache读取文件,文件必须由torch.save()保存

Arguments:
  • 路径 (str): 本地系统中的文件路径

Returns:
  • (data): 反序列化的数据

从路径读取

ding.utils.file_helper.read_from_path(path: str)[source]
Overview:

从ceph读取文件

Arguments:
  • 路径 (str): 在ceph中的文件路径,以"s3://"开头,或使用本地文件系统

Returns:
  • (data): 反序列化的数据

保存文件到Ceph

ding.utils.file_helper.save_file_ceph(path, data)[source]
Overview:

将pickle转储的数据文件保存到ceph

Arguments:
  • 路径 (str): 在ceph中的文件路径,以"s3://"开头,否则使用文件系统

  • 数据 (Any): 可以是字典、列表或张量等。

保存文件到Redis

ding.utils.file_helper.save_file_redis(path, data)[source]
Overview:

将pickle转储的数据文件保存到redis

Arguments:
  • 路径 (str): Redis中的文件路径(可以是字符串键)

  • 数据 (Any): 可以是字典、列表或张量等。

保存文件到Redis集群

ding.utils.file_helper.save_file_rediscluster(path, data)[source]
Overview:

将pickle转储的数据文件保存到rediscluster

Arguments:
  • 路径 (str): Redis中的文件路径(可以是字符串键)

  • 数据 (Any): 可以是字典、列表或张量等。

读取文件

ding.utils.file_helper.read_file(path: str, fs_type: str | None = None, use_lock: bool = False) object[source]
Overview:

从路径读取文件

Arguments:
  • 路径 (str): 要读取的文件的路径

  • fs_type (strNone): 文件系统类型,支持 {'normal', 'ceph'}

  • use_lock (bool): 是否在本地普通文件系统中使用 use_lock

保存文件

ding.utils.file_helper.save_file(path: str, data: object, fs_type: str | None = None, use_lock: bool = False) None[source]
Overview:

将数据保存到路径的文件中

Arguments:
  • 路径 (str): 要保存文件的路径

  • 数据 (object): 要保存的数据

  • fs_type (strNone): 文件系统类型,支持 {'normal', 'ceph'}

  • use_lock (bool): 是否在本地普通文件系统中使用 use_lock

remove_file

ding.utils.file_helper.remove_file(path: str, fs_type: str | None = None) None[source]
Overview:

删除文件

Arguments:
  • 路径 (str): 你想要删除的文件的路径

  • fs_type (strNone): 文件系统类型,支持 {'normal', 'ceph'}

import_helper

请参考 ding/utils/import_helper 获取更多详细信息。

尝试导入Ceph

ding.utils.import_helper.try_import_ceph()[源代码]
Overview:

尝试导入ceph模块,如果失败,返回None

Returns:
  • (Module): 导入的模块,或当未找到ceph时为None

尝试导入mc

ding.utils.import_helper.try_import_mc()[source]
Overview:

尝试导入mc模块,如果失败,返回None

Returns:
  • (Module): 导入的模块,或当未找到mc时为None

尝试导入Redis

ding.utils.import_helper.try_import_redis()[source]
Overview:

尝试导入redis模块,如果失败,返回None

Returns:
  • (Module): 导入的模块,当未找到redis时为None

尝试导入rediscluster

ding.utils.import_helper.try_import_rediscluster()[源代码]
Overview:

尝试导入rediscluster模块,如果失败,返回None

Returns:
  • (Module): 导入的模块,当未找到rediscluster时为None

import_module

ding.utils.import_helper.import_module(modules: List[str]) None[source]
Overview:

将多个模块作为列表导入

Arguments:
  • (str list): 模块名称列表

k8s_helper

请参考 ding/utils/k8s_helper 获取更多详细信息。

get_operator_server_kwargs

ding.utils.k8s_helper.get_operator_server_kwargs(cfg: EasyDict) dict[source]
Overview:

从配置文件中获取关键字参数字典

Arguments:
  • cfg (EasyDict) 系统配置

Returns:
  • 结果 (dict) 包含 api_version, namespace, name, port, host.

exist_operator_server

ding.utils.k8s_helper.exist_operator_server() bool[source]
Overview:

检查是否存在‘KUBERNETES_SERVER_URL’环境变量。

pod_exec_command

ding.utils.k8s_helper.pod_exec_command(kubeconfig: str, name: str, namespace: str, cmd: str) Tuple[int, str][source]
Overview:

在pod中执行命令

Arguments:
  • kubeconfig (str) kubeconfig文件的路径

  • 名称 (str) pod的名称

  • 命名空间 (str) pod的命名空间

K8s类型

class ding.utils.k8s_helper.K8sType(value)[source]

一个枚举。

K3s = 2
Local = 1

K8s启动器

class ding.utils.k8s_helper.K8sLauncher(config_path: str)[源代码]
Overview:

管理K8s集群的对象

Interfaces:

__init__, _load, create_cluster, _check_k3d_tools, delete_cluster, preload_images

__init__(config_path: str) None[source]
Overview:

初始化K8sLauncher对象。

Arguments:
  • config_path (str): 配置文件的路径。

_check_k3d_tools() None[source]
Overview:

检查k3d工具是否存在。

_load(config_path: str) None[source]
Overview:

加载配置文件。

Arguments:
  • config_path (str): 配置文件的路径。

create_cluster() None[source]
Overview:

创建k8s集群。

delete_cluster() None[source]
Overview:

删除k8s集群。

preload_images(images: list) None[source]
Overview:

预加载图片。

lock_helper

请参考 ding/utils/lock_helper 获取更多详细信息。

LockContextType

class ding.utils.lock_helper.LockContextType(value)[source]
Overview:

枚举以表达锁的类型。

PROCESS_LOCK = 2
THREAD_LOCK = 1

LockContext

class ding.utils.lock_helper.LockContext(lock_type: LockContextType = LockContextType.THREAD_LOCK)[source]
Overview:

生成一个LockContext以确保线程安全。

Interfaces:

__init__, __enter__, __exit__.

Example:
>>> with LockContext() as lock:
>>>     print("Do something here.")
__init__(lock_type: LockContextType = LockContextType.THREAD_LOCK)[source]
Overview:

根据给定的类型初始化锁。

Arguments:
  • lock_type (LockContextType): 要使用的锁类型。默认为 LockContextType.THREAD_LOCK。

acquire()[source]
Overview:

获取锁。

release()[source]
Overview:

释放锁。

获取读写文件锁

ding.utils.lock_helper.get_rw_file_lock(name: str, op: str)[source]
Overview:

获取具有名称和操作符的生成文件锁

Arguments:
  • 名称 (str): 锁的名称。

  • op (str): 指定的操作符,即 readwrite

Returns:
  • (RWLockFairD): 生成的读写锁

FcntlContext

class ding.utils.lock_helper.FcntlContext(lock_path: str)[source]
Overview:

一个上下文管理器,使用fcntl在文件上获取独占锁。这对于防止多个进程运行相同的代码非常有用。

Interfaces:

__init__, __enter__, __exit__.

Example:
>>> lock_path = "/path/to/lock/file"
>>> with FcntlContext(lock_path) as lock:
>>>    # Perform operations while the lock is held
__init__(lock_path: str) None[source]
Overview:

初始化LockHelper对象。

Arguments:
  • lock_path (str): 锁文件的路径。

获取文件锁

ding.utils.lock_helper.get_file_lock(name: str, op: str) FcntlContext[source]
Overview:

获取指定文件的文件锁。

Arguments:
  • 名称 (str): 文件的名称。

  • op (str): 对文件锁执行的操作。

日志助手

请参考 ding/utils/log_helper 获取更多详情。

build_logger

ding.utils.log_helper.build_logger(path: str, name: str | None = None, need_tb: bool = True, need_text: bool = True, text_level: int | str = 20) Tuple[Logger | None, SummaryWriter | None][source]
Overview:

构建文本记录器和TensorBoard记录器。

Arguments:
  • 路径 (str): Logger(Textlogger & SummaryWriter) 的保存目录

  • 名称 (str): 日志文件名

  • need_tb (bool): 是否创建并返回 SummaryWriter 实例

  • need_text (bool): 是否创建并返回 loggingLogger 实例

  • text_level (int`str): logging.Logger 的日志级别,默认设置为 logging.INFO

Returns:
  • logger (Optional[logging.Logger]): 显示终端输出的日志记录器

  • tb_logger (Optional['SummaryWriter']): 将输出保存到tfboard,仅在need_tb时返回。

TBLoggerFactory

class ding.utils.log_helper.TBLoggerFactory[source]
Overview:

TBLoggerFactory 是 SummaryWriter 的工厂类。

Interfaces:

create_logger

Properties:
  • tb_loggers (Dict[str, SummaryWriter]): 一个存储 SummaryWriter 实例的字典。

classmethod create_logger(logdir: str) DistributedWriter[source]
tb_loggers = {}

LoggerFactory

class ding.utils.log_helper.LoggerFactory[source]
Overview:

LoggerFactory 是 logging.Logger 的工厂类。

Interfaces:

create_logger, get_tabulate_vars, get_tabulate_vars_hor

classmethod create_logger(path: str, name: str = 'default', level: int | str = 20) Logger[源代码]
Overview:

使用日志记录创建记录器

Arguments:
  • 名称 (str): 日志记录器的名称

  • 路径 (str): 日志保存目录

  • 级别 (intstr): 用于设置级别。参考: Logger.setLevel 方法。

Returns:
  • (logging.Logger): 新的日志记录器

static get_tabulate_vars(variables: Dict[str, Any]) str[source]
Overview:

获取所有变量的文本描述,以表格形式显示

Arguments:
  • 变量 (List[str]): 要查询的变量名称。

Returns:
  • 字符串 (str): 以表格形式描述所有变量的文本

static get_tabulate_vars_hor(variables: Dict[str, Any]) str[source]
Overview:

获取所有变量的文本描述,以表格形式显示

Arguments:
  • 变量 (List[str]): 要查询的变量名称。

pretty_print

ding.utils.log_helper.pretty_print(result: dict, direct_print: bool = True) str[source]
Overview:

以漂亮的方式打印字典 result

Arguments:
  • 结果 (dict): 要打印的结果

  • direct_print (bool): 是否直接打印

Returns:
  • 字符串 (str): 以字符串格式漂亮打印的结果

日志写入助手

请参考 ding/utils/log_writer_helper 获取更多详细信息。

分布式写入器

class ding.utils.log_writer_helper.DistributedWriter(*args, **kwargs)[source]
Overview:

一个简单的SummaryWriter子类,支持在多进程模式下写入一个进程。 最好的方法是与router结合使用,以利用路由器的消息和事件组件(参见writer.plugin)。

Interfaces:

get_instance, plugin, initialize, __del__

__del__()[source]
Overview:

关闭文件写入器。

classmethod get_instance(*args, **kwargs) DistributedWriter[source]
Overview:

获取实例并在第一次调用时设置根级别实例。如果args和kwargs为空,此方法将返回根实例。

Arguments:
  • args (Tuple): 传递给父类 __init__ 函数的参数, SummaryWriter。

  • kwargs (Dict): 传递给父类 __init__ 函数的关键字参数, SummaryWriter。

plugin(router: Parallel, is_writer: bool = False) DistributedWriter[source]
Overview:

插件 router,因此当使用此写入器并激活路由器时,它将自动将请求发送到主写入器,而不是将其写入磁盘。这样我们可以从多个进程中收集数据并将其写入一个文件。

Arguments:
  • 路由器 (Parallel): 要插入的路由器。

  • is_writer (bool): 此写入器是否为主写入器。

Examples:
>>> DistributedWriter().plugin(router, is_writer=True)

enable_parallel

ding.utils.log_writer_helper.enable_parallel(fn_name, fn)[source]
Overview:

用于启用并行写入的装饰器。

Arguments:
  • fn_name (str): 要调用的函数的名称。

  • fn (Callable): 要调用的函数。

normalizer_helper

请参考 ding/utils/normalizer_helper 获取更多详细信息。

数据集标准化器

class ding.utils.normalizer_helper.DatasetNormalizer(dataset: ndarray, normalizer: str, path_lengths: list | None = None)[source]
Overview:

DatasetNormalizer 类提供了对数据集中数据进行归一化和反归一化的功能。 它以数据集作为输入,并对数据集中的每个键应用归一化函数。

Interfaces:

__init__, __repr__, normalize, unnormalize.

__init__(dataset: ndarray, normalizer: str, path_lengths: list | None = None)[源代码]
Overview:

初始化NormalizerHelper对象。

Arguments:
  • 数据集 (np.ndarray): 需要标准化的数据集。

  • normalizer (str): 要使用的归一化器类型。可以是一个表示归一化器类名称的字符串。

  • 路径长度 (list): 数据集中路径的长度。默认为 None。

normalize(x: ndarray, key: str) ndarray[source]
Overview:

使用指定的键对输入数据进行归一化。

Arguments:
  • x (np.ndarray): 需要标准化的输入数据。

  • key (:obj`str`): 用于标识标准化器的键。

Returns:
  • ret (np.ndarray): 输入数据的归一化值。

unnormalize(x: ndarray, key: str) ndarray[源代码]
Overview:

使用指定的key对给定的值x进行非规范化处理。

Arguments:
  • x (np.ndarray): 需要反归一化的值。

  • key (:obj`str`): 用于标识标准化器的键。

Returns:
  • ret (np.ndarray): 未归一化的值。

flatten

ding.utils.normalizer_helper.flatten(dataset: dict, path_lengths: list) dict[source]
Overview:

将数据集从 { key: [ n_episodes x max_path_length x dim ] } 展平为 { key : [ (n_episodes * sum(path_lengths)) x dim ] }

Arguments:
  • 数据集 (dict): 需要展平的数据集。

  • path_lengths (list): 每个情节的路径长度列表。

Returns:
  • 扁平化 (dict): 扁平化的数据集。

标准化器

class ding.utils.normalizer_helper.Normalizer(X)[source]
Overview:

父类,通过定义normalizeunnormalize方法进行子类化

Interfaces:

__init__, __repr__, normalize, unnormalize.

__init__(X)[源代码]
Overview:

初始化Normalizer对象。

Arguments:
  • X (np.ndarray): 需要标准化的数据。

normalize(*args, **kwargs)[source]
Overview:

标准化输入数据。

Arguments:
  • args (list): 传递给 normalize 函数的参数。

  • kwargs (dict): 传递给 normalize 函数的关键字参数。

unnormalize(*args, **kwargs)[source]
Overview:

对输入数据进行非规范化处理。

Arguments:
  • args (list): 传递给unnormalize函数的参数。

  • kwargs (dict): 传递给unnormalize函数的关键字参数。

高斯归一化器

class ding.utils.normalizer_helper.GaussianNormalizer(*args, **kwargs)[source]
Overview:

一个将数据标准化为零均值和单位方差的类。

Interfaces:

__init__, __repr__, normalize, unnormalize.

__init__(*args, **kwargs)[source]
Overview:

初始化GaussianNormalizer对象。

Arguments:
  • args (list): 传递给父类 __init__ 函数的参数,即 Normalizer 类。

  • kwargs (dict): 传递给父类__init__函数的关键字参数,即Normalizer类。

normalize(x: ndarray) ndarray[source]
Overview:

标准化输入数据。

Arguments:
  • x (np.ndarray): 需要标准化的输入数据。

Returns:
  • ret (np.ndarray): 归一化后的数据。

unnormalize(x: ndarray) ndarray[来源]
Overview:

对输入数据进行非规范化处理。

Arguments:
  • x (np.ndarray): 需要反归一化的输入数据。

Returns:
  • ret (np.ndarray): 未归一化的数据。

CDFNormalizer

class ding.utils.normalizer_helper.CDFNormalizer(X)[source]
Overview:

一个通过使用边际CDF转换数据以使训练数据在每个维度上均匀的类。

Interfaces:

__init__, __repr__, normalize, unnormalize.

__init__(X)[源代码]
Overview:

初始化CDFNormalizer对象。

Arguments:
  • X (np.ndarray): 需要标准化的数据。

normalize(x: ndarray) ndarray[source]
Overview:

标准化输入数据。

Arguments:
  • x (np.ndarray): 输入数据。

Returns:
  • ret (np.ndarray): 归一化后的数据。

unnormalize(x: ndarray) ndarray[source]
Overview:

对输入数据进行非规范化处理。

Arguments:
  • x (np.ndarray): 输入数据。

Returns:
  • ret (np.ndarray):: 未归一化的数据。

wrap(fn_name: str, x: ndarray) ndarray[source]
Overview:

将给定的函数名称包装并应用于输入数据。

Arguments:
  • fn_name (str): 要应用的函数的名称。

  • x (np.ndarray): 输入数据。

Returns:
  • ret: 应用于输入数据的函数的输出。

CDFNormalizer1d

class ding.utils.normalizer_helper.CDFNormalizer1d(X: ndarray)[源代码]
Overview:

单维度的CDF标准化器。该类提供了使用累积分布函数(CDF)方法对数据进行标准化和反标准化的方法。

Interfaces:

__init__, __repr__, normalize, unnormalize.

__init__(X: ndarray)[source]
Overview:

初始化CDFNormalizer1d对象。

Arguments:
  • X (np.ndarray): 需要标准化的数据。

normalize(x: ndarray) ndarray[source]
Overview:

标准化输入数据。

Arguments:
  • x (np.ndarray): 需要标准化的数据。

Returns:
  • ret (np.ndarray): 归一化后的数据。

unnormalize(x: ndarray, eps: float = 0.0001) ndarray[source]
Overview:

对输入数据进行非规范化处理。

Arguments:
  • x (np.ndarray): 需要反归一化的数据。

  • eps (float): 用于数值稳定性的小值。默认为1e-4。

Returns:
  • ret (np.ndarray): 未归一化的数据。

经验累积分布函数

ding.utils.normalizer_helper.empirical_cdf(sample: ~numpy.ndarray) -> (<class 'numpy.ndarray'>, <class 'numpy.ndarray'>)[source]
Overview:

计算给定样本的经验累积分布函数(CDF)。

Arguments:
  • 样本 (np.ndarray): 用于计算经验CDF的输入样本。

Returns:
  • 分位数 (np.ndarray): 样本中的唯一值。

  • cumprob (np.ndarray): 对应于分位数的累积概率。

References:

至少二维

ding.utils.normalizer_helper.atleast_2d(x: ndarray) ndarray[source]
Overview:

确保输入数组至少有两个维度。

Arguments:
  • x (np.ndarray): 输入数组。

Returns:
  • ret (np.ndarray): 输入的数组,至少有两个维度。

LimitsNormalizer

class ding.utils.normalizer_helper.LimitsNormalizer(X)[source]
Overview:

一个在指定范围内标准化和非标准化值的类。该类将[xmin, xmax]范围内的值映射到[-1, 1]范围内。

Interfaces:

__init__, __repr__, normalize, unnormalize.

normalize(x: ndarray) ndarray[source]
Overview:

标准化输入值。

Argments:
  • x (np.ndarray): 需要标准化的输入值。

Returns:
  • ret (np.ndarray): 归一化后的值。

unnormalize(x: ndarray, eps: float = 0.0001) ndarray[源代码]
Overview:

对输入值进行反归一化。

Arguments:
  • x (np.ndarray): 需要反归一化的输入值。

  • eps (float): 用于裁剪的小值。默认为 1e-4。

Returns:
  • ret (np.ndarray): 未归一化的值。

orchestrator_launcher

请参考 ding/utils/orchestrator_launcher 获取更多详细信息。

OrchestratorLauncher

class ding.utils.orchestrator_launcher.OrchestratorLauncher(version: str, name: str = 'di-orchestrator', cluster: K8sLauncher | None = None, registry: str = 'diorchestrator', cert_manager_version: str = 'v1.3.1', cert_manager_registry: str = 'quay.io/jetstack')[源代码]
Overview:

用于管理现有k8s集群中的di-orchestrator的对象

Interfaces:

__init__, create_orchestrator, delete_orchestrator

__init__(version: str, name: str = 'di-orchestrator', cluster: K8sLauncher | None = None, registry: str = 'diorchestrator', cert_manager_version: str = 'v1.3.1', cert_manager_registry: str = 'quay.io/jetstack') None[source]
Overview:

初始化OrchestratorLauncher对象。

Arguments:
  • 版本 (str): di-orchestrator 的版本。

  • 名称 (str): di-orchestrator 的名称。

  • 集群 (K8sLauncher): 用于部署 di-orchestrator 的 k8s 集群。

  • registry (str): 用于拉取镜像的docker仓库。

  • cert_manager_version (str): cert-manager的版本。

  • cert_manager_registry (str): 用于拉取cert-manager镜像的docker仓库。

_check_kubectl_tools() None[source]
Overview:

检查是否安装了kubectl工具。

create_orchestrator() None[源代码]
Overview:

在k8s集群中创建di-orchestrator。

delete_orchestrator() None[源代码]
Overview:

删除k8s集群中的di-orchestrator。

从配置创建组件

ding.utils.orchestrator_launcher.create_components_from_config(config: str) None[source]
Overview:

从配置文件创建组件。

Arguments:
  • 配置文件 (str): 配置文件。

等待准备就绪

ding.utils.orchestrator_launcher.wait_to_be_ready(namespace: str, component: str, timeout: int = 120) None[source]
Overview:

等待组件准备就绪。

Arguments:
  • 命名空间 (str): 组件的命名空间。

  • 组件 (str): 组件的名称。

  • 超时时间 (int): 等待的超时时间。

性能分析助手

请参考 ding/utils/profiler_helper 获取更多详细信息。

性能分析器

class ding.utils.profiler_helper.Profiler[source]
Overview:

一个用于分析代码执行的类。它可以作为上下文管理器或装饰器使用。

Interfaces:

__init__, mkdir, write_profile, profile.

__init__()[源代码]
Overview:

初始化 Profiler 对象。

mkdir(directory: str)[源代码]
OverView:

如果目录不存在,则创建一个目录。

Arguments:
  • 目录 (str): 要创建的目录的路径。

profile(folder_path='./tmp')[source]
OverView:

启用性能分析并将结果保存到文件。

Arguments:
  • folder_path (str): 保存分析文件的文件夹路径。默认为“./tmp”。

write_profile(pr: Profile, folder_path: str)[source]
OverView:

将分析结果写入文件。

Arguments:
  • pr (cProfile.Profile): 包含分析结果的分析器对象。

  • folder_path (str): 保存分析文件的文件夹路径。

pytorch_ddp_dist_helper

请参考 ding/utils/pytorch_ddp_dist_helper 了解更多详情。

获取排名

ding.utils.pytorch_ddp_dist_helper.get_rank() int[来源]
Overview:

获取当前进程在总世界大小中的排名

get_world_size

ding.utils.pytorch_ddp_dist_helper.get_world_size() int[source]
Overview:

获取world_size(数据并行训练中的总进程数)

allreduce

ding.utils.pytorch_ddp_dist_helper.allreduce(x: Tensor) None[source]
Overview:

在世界范围内减少张量 x

Arguments:
  • x (torch.Tensor): 要减少的张量

allreduce_async

ding.utils.pytorch_ddp_dist_helper.allreduce_async(name: str, x: Tensor) None[source]
Overview:

异步减少全局中的张量 x

Arguments:
  • 名称 (str): 张量的名称

  • x (torch.Tensor): 要减少的张量

reduce_data

ding.utils.pytorch_ddp_dist_helper.reduce_data(x: int | float | Tensor, dst: int) int | float | Tensor[source]
Overview:

将张量 x 减少到目标进程 dst

Arguments:
  • x (Union[int, float, torch.Tensor]): 要减少的张量

  • dst (int): 目标进程

allreduce_data

ding.utils.pytorch_ddp_dist_helper.allreduce_data(x: int | float | Tensor, op: str) int | float | Tensor[source]
Overview:

在世界范围内减少张量 x

Arguments:
  • x (Union[int, float, torch.Tensor]): 要减少的张量

  • op (str): 对数据执行的操作,支持 ['sum', 'avg']

get_group

ding.utils.pytorch_ddp_dist_helper.get_group(group_size: int) List[source]
Overview:

获取每个组的group_size的分组分割

Arguments:
  • group_size (int) 是 group_size

dist_mode

ding.utils.pytorch_ddp_dist_helper.dist_mode(func: Callable) Callable[源代码]
Overview:

包装函数,使其在每次调用前可以自动初始化和结束

Arguments:
  • func (Callable): 要包装的函数

dist_init

ding.utils.pytorch_ddp_dist_helper.dist_init(backend: str = 'nccl', addr: str | None = None, port: str | None = None, rank: int | None = None, world_size: int | None = None) Tuple[int, int][源代码]
Overview:

初始化分布式训练设置

Arguments:
  • backend (str): 分布式训练的后端,支持 ['nccl', 'gloo']

  • addr (str): 主节点的地址

  • 端口 (str): 主节点的端口

  • rank (int): 当前进程的排名

  • world_size (int): 进程的总数

dist_finalize

ding.utils.pytorch_ddp_dist_helper.dist_finalize() None[source]
Overview:

完成分布式训练资源的配置

DDP上下文

class ding.utils.pytorch_ddp_dist_helper.DDPContext[源代码]
Overview:

一个用于 linklink 分发的上下文管理器

Interfaces:

__init__, __enter__, __exit__

__init__() None[source]
Overview:

初始化 DDPContext

simple_group_split

ding.utils.pytorch_ddp_dist_helper.simple_group_split(world_size: int, rank: int, num_groups: int) List[源代码]
Overview:

根据worldsizeranknum_groups分割组

Arguments:
  • world_size (int): 世界大小

  • 排名 (int): 排名

  • num_groups (int): 组的数量

注意

如果输入有误,引发array split does not result in an equal division

to_ddp_config

ding.utils.pytorch_ddp_dist_helper.to_ddp_config(cfg: EasyDict) EasyDict[来源]
Overview:

将配置转换为ddp配置

Arguments:
  • cfg (EasyDict): 要转换的配置

注册表

请参考 ding/utils/registry 获取更多详细信息。

注册表

class ding.utils.registry.Registry(*args, **kwargs)[source]
Overview:

一个用于管理注册模块的辅助类,它扩展了一个字典并提供了注册功能。

Interfaces:

__init__, register, get, build, query, query_details

Examples (creating):
>>> some_registry = Registry({"default": default_module})
Examples (registering: normal way):
>>> def foo():
>>>     ...
>>> some_registry.register("foo_module", foo)
Examples (registering: decorator way):
>>> @some_registry.register("foo_module")
>>> @some_registry.register("foo_modeul_nickname")
>>> def foo():
>>>     ...
Examples (accessing):
>>> f = some_registry["foo_module"]
__init__(*args, **kwargs) None[source]
Overview:

初始化注册表对象。

Arguments:
  • args (Tuple): 传递给父类 __init__ 函数的参数, dict。

  • kwargs (Dict): 传递给父类 __init__ 函数的关键字参数, dict.

static _register_generic(module_dict: dict, module_name: str, module: Callable, force_overwrite: bool = False) None[source]
Overview:

注册模块。

Arguments:
  • module_dict (dict): 用于存储模块的字典。

  • 模块名称 (str): 模块的名称。

  • 模块 (Callable): 要注册的模块。

  • force_overwrite (bool): 是否覆盖同名模块。

build(obj_type: str, *obj_args, **obj_kwargs) object[source]
Overview:

构建对象。

Arguments:
  • obj_type (str): 对象的类型。

  • obj_args (Tuple): 传递给对象的参数。

  • obj_kwargs (Dict): 传递给对象的关键字参数。

get(module_name: str) Callable[source]
Overview:

获取模块。

Arguments:
  • 模块名称 (str): 模块的名称。

query() Iterable[source]
Overview:

所有注册的模块名称。

query_details(aliases: Iterable | None = None) OrderedDict[source]
Overview:

获取已注册模块的详细信息。

Arguments:
  • 别名 (Optional[Iterable]): 模块的别名。

register(module_name: str | None = None, module: Callable | None = None, force_overwrite: bool = False) Callable[source]
Overview:

注册模块。

Arguments:
  • 模块名称 (Optional[str]): 模块的名称。

  • 模块 (Optional[Callable]): 要注册的模块。

  • force_overwrite (bool): 是否覆盖同名模块。

渲染助手

请参考 ding/utils/render_helper 获取更多详细信息。

渲染环境

ding.utils.render_helper.render_env(env, render_mode: str | None = 'rgb_array') ndarray[source]
Overview:

渲染环境的当前帧。

Arguments:
  • 环境 (gym.Env): DI-engine 环境实例。

  • render_mode (str): 渲染模式。

Returns:
  • 帧 (numpy.ndarray): [高度 * 宽度 * 通道]

渲染

ding.utils.render_helper.render(env: BaseEnv, render_mode: str | None = 'rgb_array') ndarray[source]
Overview:

渲染环境的当前帧。

Arguments:
  • env (BaseEnv): DI-engine 环境实例。

  • render_mode (str): 渲染模式。

Returns:
  • 帧 (numpy.ndarray): [H * W * C]

获取环境帧率

ding.utils.render_helper.get_env_fps(env) int[source]
Overview:

获取环境的帧率。

Arguments:
  • 环境 (gym.Env): DI-engine 环境实例。

Returns:
  • fps (int).

fps

ding.utils.render_helper.fps(env_manager: BaseEnvManager) int[source]
Overview:

渲染环境的帧率。

Arguments:
  • env (BaseEnvManager): DI-engine 环境管理器实例。

Returns:
  • fps (int).

调度助手

请参考 ding/utils/scheduler_helper 获取更多详细信息。

调度器

class ding.utils.scheduler_helper.Scheduler(merged_scheduler_config: EasyDict)[source]
Overview:

当trueskill指标停止改进时更新学习参数。 例如,一旦学习过程停滞,模型通常受益于减少熵权重。 此调度器读取指标量,如果在‘patience’个周期内没有看到改进, 则增加或减少相应的参数,这取决于‘schedule_mode’。

Arguments:
  • schedule_flag (bool): Indicates whether to use scheduler in training pipeline.

    默认值:False

  • schedule_mode (str): One of ‘reduce’, ‘add’,’multi’,’div’. The schecule_mode

    决定参数更新的方式。默认值:'reduce'。

  • factor (float)Amount (greater than 0) by which the parameter will be

    增加/减少。默认值:0.05

  • change_range (list): Indicates the minimum and maximum value

    参数可以达到的范围。默认值:[-1,1]

  • threshold (float): Threshold for measuring the new optimum,

    仅关注显著变化。默认值:1e-4。

  • optimize_mode (str): One of ‘min’, ‘max’, which indicates the sign of

    优化目标。Dynamic_threshold = last_metrics + threshold 在 max 模式下或 last_metrics - threshold 在 min 模式下。默认值:'min'

  • patience (int): Number of epochs with no improvement after which

    参数将会被更新。例如,如果patience = 2,那么我们将忽略前2个没有改进的时期,并且只有在第3个时期后如果指标仍然没有改进时才会更新参数。默认值:10。

  • cooldown (int): Number of epochs to wait before resuming

    参数更新后的正常操作。默认值:0。

Interfaces:

__init__, update_param, step

Property:

冷却中, 是否更好

__init__(merged_scheduler_config: EasyDict) None[源代码]
Overview:

初始化调度程序。

Arguments:
  • merged_scheduler_config (EasyDict): the scheduler config, which merges the user

    配置和默认配置

config = {'change_range': [-1, 1], 'cooldown': 0, 'factor': 0.05, 'optimize_mode': 'min', 'patience': 10, 'schedule_flag': False, 'schedule_mode': 'reduce', 'threshold': 0.0001}
property in_cooldown: bool
Overview:

检查调度器是否处于冷却期。如果在冷却期,调度器将忽略任何不良的周期。

is_better(cur: float) bool[source]
Overview:

检查当前指标是否相对于阈值优于上一个指标。

Args:
  • cur (float): 当前指标

step(metrics: float, param: float) float[source]
Overview:

决定是否更新计划的参数

Args:
  • 指标 (float): 当前输入指标

  • 参数 (float): 需要更新的参数

Returns:
  • step_param (float): 一步之后的参数

update_param(param: float) float[source]
Overview:

更新调度参数

Args:
  • 参数 (float): 需要更新的参数

Returns:
  • 更新后的参数 (float): 更新后的参数

段树

请参考 ding/utils/segment_tree 获取更多详细信息。

njit

ding.utils.segment_tree.njit()[source]
Overview:

使用numba编译函数的装饰器。

线段树

class ding.utils.segment_tree.SegmentTree(capacity: int, operation: Callable, neutral_element: float | None = None)[source]
Overview:

线段树数据结构,由树状数组实现。只有叶子节点是真实值,非叶子节点是对其左右子节点进行一些操作。

Interfaces:

__init__, reduce, __setitem__, __getitem__

__init__(capacity: int, operation: Callable, neutral_element: float | None = None) None[source]
Overview:

初始化线段树。树的根节点位于索引1。

Arguments:
  • 容量 (int): 树的容量(叶子节点的数量),应该是2的幂。

  • 操作 (function): 用于构建树的操作函数,例如 sum, max, min 等。

  • 中性元素 (floatNone): 中性元素的值,用于初始化树中所有节点的值。

_compile() None[source]
Overview:

使用numba编译函数。

reduce(start: int = 0, end: int | None = None) float[source]
Overview:

在范围 [start, end) 内减少树

Arguments:
  • 开始 (int): 起始索引(相对索引,第一个叶子节点为0),默认设置为0

  • 结束 (intNone): 结束索引(相对索引),默认设置为 self.capacity

Returns:
  • reduce_result (float): 减少结果值,取决于数据类型和操作

求和线段树

class ding.utils.segment_tree.SumSegmentTree(capacity: int)[source]
Overview:

求和线段树,继承自 SegmentTree。通过传递 operation='sum' 进行初始化。

Interfaces:

__init__, find_prefixsum_idx

__init__(capacity: int) None[source]
Overview:

通过传递 operation='sum' 初始化求和线段树

Arguments:
  • 容量 (int): 树的容量(叶子节点的数量)。

find_prefixsum_idx(prefixsum: float, trust_caller: bool = True) int[source]
Overview:

找到最高的非零索引 i,sum_{j}leaf[j] <= prefixsum(其中 0 <= j < i) 并且 sum_{j}leaf[j] > prefixsum(其中 0 <= j < i+1)

Arguments:
  • prefixsum (float): 目标前缀和。

  • trust_caller (bool): Whether to trust caller, which means whether to check whether this tree’s sum is greater than the input prefixsum by calling reduce function.

    默认设置为True。

Returns:
  • idx (int): 符合条件的索引。

最小段树

class ding.utils.segment_tree.MinSegmentTree(capacity: int)[source]
Overview:

最小段树,继承自 SegmentTree。通过传递 operation='min' 进行初始化。

Interfaces:

__init__

__init__(capacity: int) None[源代码]
Overview:

通过传递 operation='min' 初始化求和线段树

Arguments:
  • 容量 (int): 树的容量(叶子节点的数量)。

_setitem

ding.utils.segment_tree._setitem(tree: ndarray, idx: int, val: float, operation: str) None
Overview:

设置 tree[idx] = val; 然后更新相关节点。

Arguments:
  • 树 (np.ndarray): 树数组。

  • idx (int): 叶子节点的索引。

  • val (float): 将被赋值给 leaf[idx] 的值。

  • 操作 (str): 用于构建树的操作函数,例如 sum, max, min 等。

_reduce

ding.utils.segment_tree._reduce(tree: ndarray, start: int, end: int, neutral_element: float, operation: str) float
Overview:

在范围 [start, end) 内减少树

Arguments:
  • 树 (np.ndarray): 树数组。

  • 开始 (int): 开始索引(相对索引,第一个叶子节点为0)。

  • 结束 (int): 结束索引(相对索引)。

  • neutral_element (float): 中性元素的值,用于初始化树中所有节点的值。

  • 操作 (str): 用于构建树的操作函数,例如 sum, max, min 等。

_find_prefixsum_idx

ding.utils.segment_tree._find_prefixsum_idx(tree: ndarray, capacity: int, prefixsum: float, neutral_element: float) int
Overview:

找到最高的非零索引 i,sum_{j}leaf[j] <= prefixsum(其中 0 <= j < i) 并且 sum_{j}leaf[j] > prefixsum(其中 0 <= j < i+1)

Arguments:
  • 树 (np.ndarray): 树数组。

  • 容量 (int): 树的容量(叶子节点的数量)。

  • prefixsum (float): 目标前缀和。

  • neutral_element (float): 中性元素的值,用于初始化树中所有节点的值。

slurm_helper

请参考 ding/utils/slurm_helper 获取更多详细信息。

获取IP

ding.utils.slurm_helper.get_ip() str[source]
Overview:

获取当前节点的IP地址

获取管理节点IP

ding.utils.slurm_helper.get_manager_node_ip(node_ip: str | None = None) str[source]
Overview:

查找slurm集群的管理节点并返回节点IP

Arguments:
  • node_ip (Optional[str]): 当前节点的IP地址

获取类信息

ding.utils.slurm_helper.get_cls_info() Dict[str, list][source]
Overview:

获取集群信息

节点到分区

ding.utils.slurm_helper.node_to_partition(target_node: str) Tuple[str, str][source]
Overview:

获取目标节点的分区

Arguments:
  • target_node (str): 目标节点

node_to_host

ding.utils.slurm_helper.node_to_host(node: str) str[source]
Overview:

获取节点的主机

Arguments:
  • 节点 (str): 节点

find_free_port_slurm

ding.utils.slurm_helper.find_free_port_slurm(node: str) int[源代码]
Overview:

在节点上找到一个空闲端口

Arguments:
  • 节点 (str): 节点

system_helper

请参考 ding/utils/system_helper 获取更多详情。

获取IP

ding.utils.system_helper.get_ip() str[source]
Overview:

获取套接字的ip(host)

Returns:
  • ip(str): 对应的IP地址

获取进程ID

ding.utils.system_helper.get_pid() int[source]
Overview:

os.getpid

获取任务UID

ding.utils.system_helper.get_task_uid() str[source]
Overview:

获取slurm的job_idpiduid

传播线程

class ding.utils.system_helper.PropagatingThread(group=None, target=None, name=None, args=(), kwargs=None, *, daemon=None)[source]
Overview:

Thread的子类,将线程中的执行异常传播给调用者

Interfaces:

run, join

Examples:
>>> def func():
>>>     raise Exception()
>>> t = PropagatingThread(target=func, args=())
>>> t.start()
>>> t.join()
join() Any[source]
Overview:

加入线程

run() None[source]
Overview:

运行线程

find_free_port

ding.utils.system_helper.find_free_port(host: str) int[source]
Overview:

查找空闲端口列表并返回一个

Arguments:
  • 主机 (str): 主机

time_helper_base

请参考 ding/utils/time_helper_base 获取更多详细信息。

时间包装器

class ding.utils.time_helper_base.TimeWrapper[source]
Overview:

定义TimeWrapper类的抽象类方法

Interfaces:

wrapper, start_time, end_time

classmethod end_time()[来源]
Overview:

抽象类方法,停止计时

classmethod start_time()[source]
Overview:

抽象类方法,开始计时

classmethod wrapper(fn)[source]
Overview:

类方法包装器,包装一个函数并自动返回其运行时间

Arguments:
  • fn (function): 需要被包装和计时的函数

time_helper_cuda

请参考 ding/utils/time_helper_cuda 获取更多详细信息。

get_cuda_time_wrapper

ding.utils.time_helper_cuda.get_cuda_time_wrapper() Callable[[], TimeWrapper][源代码]
Overview:

返回TimeWrapperCuda类,此包装器旨在确保在没有CUDA设备时的兼容性

Returns:
  • TimeWrapperCuda(class): 参见 TimeWrapperCuda

注意

必须使用 torch.cuda.synchronize(),参考:<https://blog.csdn.net/u013548568/article/details/81368019>

时间助手

请参考 ding/utils/time_helper 获取更多详细信息。

build_time_helper

ding.utils.time_helper.build_time_helper(cfg: EasyDict | None = None, wrapper_type: str | None = None) Callable[[], TimeWrapper][source]
Overview:

构建时间助手

Arguments:
  • cfg (dict):

    配置文件是一个多级字典,包含诸如评估、通用、模型、训练等大领域,每个大领域都有其较小的领域。

  • wrapper_type (str): 返回的包装器类型,支持 ['time', 'cuda']

Returns:
  • time_wrapper (TimeWrapper):

    返回相应的时间包装器,参考:ding.utils.timehelper.TimeWrapperTimeding.utils.timehelper.get_cuda_time_wrapper

EasyTimer

class ding.utils.time_helper.EasyTimer(cuda=True)[源代码]
Overview:

一个体面的计时器包装器,可以轻松使用。

Interfaces:

__init__, __enter__, __exit__

Example:
>>> wait_timer = EasyTimer()
>>> with wait_timer:
>>>    func(...)
>>> time_ = wait_timer.value  # in second
__init__(cuda=True)[source]
Overview:

初始化类 EasyTimer

Arguments:
  • cuda (bool): 是否使用cuda类型构建计时器

时间包装时间

class ding.utils.time_helper.TimeWrapperTime[source]
Overview:

一个继承自 TimeWrapper 类的类方法

Interfaces:

start_time, end_time

classmethod end_time()[source]
Overview:

TimeWrapper类中实现并重写end_time方法

Returns:
  • 时间(float): start_time 和 end_time 之间的时间

classmethod start_time()[source]
Overview:

实现并重写TimeWrapper类中的start_time方法

看门狗

class ding.utils.time_helper.WatchDog(timeout: int = 1)[source]
Overview:

简单的看门狗定时器用于检测超时

Arguments:
  • 超时时间 (int): watchdog [seconds] 的超时值。

注意

如果在此值之前未重置,TimeourError 将会被抛出。

Interfaces:

start, stop

Examples:
>>> watchdog = WatchDog(x) # x is a timeout value
>>> ...
>>> watchdog.start()
>>> ... # Some function
__init__(timeout: int = 1)[source]
Overview:

使用timeout值初始化看门狗。

Arguments:
  • 超时时间 (int): watchdog [seconds] 的超时值。

static _event(signum: Any, frame: Any)[source]
Overview:

看门狗的事件处理程序。

Arguments:
  • signum (Any): 信号编号。

  • frame (Any): 当前堆栈帧。

start()[source]
Overview:

启动看门狗。

stop()[source]
Overview:

使用alarm(0)SIGALRMSIG_DFL信号停止看门狗。

loader.base

请参考 ding/utils/loader/base 获取更多详细信息。

ILoaderClass

class ding.utils.loader.base.ILoaderClass[source]
Overview:

加载器的基类。

Interfaces:

__init__, _load, load, check, __call__, __and__, __or__, __rshift__

__check(value: _ValueType) bool
Overview:

检查值是否有效。

Arguments:
  • 值 (_ValueType): 要检查的值。

__load(value: _ValueType) _ValueType
Overview:

加载值。

Arguments:
  • 值 (_ValueType): 要加载的值。

abstract _load(value: _ValueType) _ValueType[source]
Overview:

加载值。

Arguments:
  • 值 (_ValueType): 要加载的值。

check(value: _ValueType) bool[source]
Overview:

检查值是否有效。

Arguments:
  • 值 (_ValueType): 要检查的值。

load(value: _ValueType) _ValueType[source]
Overview:

加载值。

Arguments:
  • 值 (_ValueType): 要加载的值。

loader.collection

请参考 ding/utils/loader/collection 获取更多详细信息。

CollectionError

class ding.utils.loader.collection.CollectionError(errors: List[Tuple[int, Exception]])[source]
Overview:

收集错误。

Interfaces:

__init__, errors

Properties:

errors

__init__(errors: List[Tuple[int, Exception]])[源代码]
Overview:

初始化CollectionError。

Arguments:
  • 错误 (COLLECTION_ERRORS): 错误信息。

_abc_impl = <_abc._abc_data object>
property errors: List[Tuple[int, Exception]]
Overview:

获取错误。

集合

ding.utils.loader.collection.collection(loader, type_back: bool = True) ILoaderClass[source]
Overview:

创建一个集合加载器。

Arguments:
  • 加载器 (ILoaderClass): 加载器。

  • type_back (bool): 是否转换回类型。

tuple

ding.utils.loader.collection.tuple_(*loaders) ILoaderClass[source]
Overview:

创建一个元组加载器。

Arguments:
  • 加载器 (tuple): 加载器。

长度

ding.utils.loader.collection.length(min_length: int | None = None, max_length: int | None = None) ILoaderClass[源代码]
Overview:

创建一个长度加载器。

Arguments:
  • min_length (int): 最小长度。

  • max_length (int): 最大长度。

length_is

ding.utils.loader.collection.length_is(length_: int) ILoaderClass[source]
Overview:

创建一个长度加载器。

Arguments:

包含

ding.utils.loader.collection.contains(content) ILoaderClass[source]
Overview:

创建一个包含加载器。

Arguments:
  • 内容 (Any): 内容。

cofilter

ding.utils.loader.collection.cofilter(checker: Callable[[Any], bool], type_back: bool = True) ILoaderClass[source]
Overview:

创建一个cofilter加载器。

Arguments:
  • 检查器 (Callable[[Any], bool]): 检查器。

  • type_back (bool): 是否转换回类型。

tpselector

ding.utils.loader.collection.tpselector(*indices) ILoaderClass[source]
Overview:

创建一个元组选择器加载器。

Arguments:
  • 索引 (tuple): 索引。

loader.dict

请参考 ding/utils/loader/dict 获取更多详细信息。

字典错误

class ding.utils.loader.dict.DictError(errors: Mapping[str, Exception])[source]
Overview:

字典错误。

Interfaces:

__init__, errors

Properties:

errors

__init__(errors: Mapping[str, Exception])[源代码]
Overview:

初始化DictError。

Arguments:
  • 错误 (DICT_ERRORS): 错误信息。

_abc_impl = <_abc._abc_data object>
property errors: Mapping[str, Exception]
Overview:

获取错误。

dict

ding.utils.loader.dict.dict_(**kwargs) ILoaderClass[source]
Overview:

创建一个字典加载器。

Arguments:
  • kwargs (Mapping[str, ILoaderClass]): 加载器。

loader.exception

请参考 ding/utils/loader/exception 获取更多详细信息。

复合结构错误

class ding.utils.loader.exception.CompositeStructureError[source]
Overview:

复合结构错误。

Interfaces:

__init__, errors

Properties:

errors

_abc_impl = <_abc._abc_data object>
abstract property errors: List[Tuple[int | str, Exception]]
Overview:

获取错误。

loader.mapping

请参考 ding/utils/loader/mapping 获取更多详细信息。

映射错误

class ding.utils.loader.mapping.MappingError(key_errors: List[Tuple[str, Exception]], value_errors: List[Tuple[str, Exception]])[source]
Overview:

映射错误。

Interfaces:

__init__, errors

__init__(key_errors: List[Tuple[str, Exception]], value_errors: List[Tuple[str, Exception]])[source]
Overview:

初始化MappingError。

Arguments:
  • 键错误 (MAPPING_ERRORS): 键错误。

  • value_errors (MAPPING_ERRORS): 值错误。

_abc_impl = <_abc._abc_data object>
errors() List[Tuple[str, Exception]][source]
Overview:

获取错误。

key_errors() List[Tuple[str, Exception]][source]
Overview:

获取关键错误。

value_errors() List[Tuple[str, Exception]][source]
Overview:

获取值错误。

映射

ding.utils.loader.mapping.mapping(key_loader, value_loader, type_back: bool = True) ILoaderClass[source]
Overview:

创建一个映射加载器。

Arguments:
  • key_loader (ILoaderClass): 密钥加载器。

  • value_loader (ILoaderClass): 值加载器。

  • type_back (bool): 是否转换回类型。

mpfilter

ding.utils.loader.mapping.mpfilter(check: Callable[[Any, Any], bool], type_back: bool = True) ILoaderClass[source]
Overview:

创建一个映射过滤器加载器。

Arguments:
  • 检查 (Callable[[Any, Any], bool]): 检查函数。

  • type_back (bool): 是否转换回类型。

mpkeys

ding.utils.loader.mapping.mpkeys() ILoaderClass[source]
Overview:

创建一个映射键加载器。

mpvalues

ding.utils.loader.mapping.mpvalues() ILoaderClass[source]
Overview:

创建一个映射值加载器。

mpitems

ding.utils.loader.mapping.mpitems() ILoaderClass[source]
Overview:

创建一个映射项加载器。

项目

ding.utils.loader.mapping.item(key) ILoaderClass[source]
Overview:

创建一个项目加载器。

Arguments:
  • 键 (Any): 键。

item_or

ding.utils.loader.mapping.item_or(key, default) ILoaderClass[source]
Overview:

创建一个项目或加载器。

Arguments:
  • 键 (Any): 键。

  • 默认值 (Any): 默认值。

loader.norm

请参考 ding/utils/loader/norm 获取更多详细信息。

_callable_to_norm

ding.utils.loader.norm._callable_to_norm(func: Callable[[Any], Any]) INormClass[source]
Overview:

将可调用对象转换为规范。

Arguments:
  • func (Callable[[Any], Any]): 要转换的可调用对象。

norm

ding.utils.loader.norm.norm(value) INormClass[source]
Overview:

将值转换为规范。

Arguments:
  • 值 (Any): 要转换的值。

normfunc

ding.utils.loader.norm.normfunc(func)[source]
Overview:

将函数转换为标准函数。

Arguments:
  • func (Callable[[Any], Any]): 要转换的函数。

_unary

ding.utils.loader.norm._unary(a: INormClass, func: Callable[[Any], Any]) INormClass[source]
Overview:

创建一个一元范数。

Arguments:
  • a (INormClass): 规范。

  • func (UNARY_FUNC): 函数。

_binary

ding.utils.loader.norm._binary(a: INormClass, b: INormClass, func: Callable[[Any, Any], Any]) INormClass[source]
Overview:

创建一个二元范数。

Arguments:
  • a (INormClass): 第一个规范。

  • b (INormClass): 第二个范数。

  • func (BINARY_FUNC): 函数。

_binary_reducing

ding.utils.loader.norm._binary_reducing(func: Callable[[Any, Any], Any], zero)[source]
Overview:

创建一个二元归约范数。

Arguments:
  • func (BINARY_FUNC): 函数。

  • 零值 (Any): 零值。

INormClass

class ding.utils.loader.norm.INormClass[source]
Overview:

范数类。

Interfaces:

__call__, __add__, __radd__, __sub__, __rsub__, __mul__, __rmul__, __matmul__, __rmatmul__, __truediv__, __rtruediv__, __floordiv__, __rfloordiv__, __mod__, __rmod__, __pow__, __rpow__, __lshift__, __rlshift__, __rshift__, __rrshift__, __and__, __rand__, __or__, __ror__, __xor__, __rxor__, __invert__, __pos__, __neg__, __eq__, __ne__, __lt__, __le__, __gt__, __ge__

abstract _call(value)[source]
Overview:

调用规范。

Arguments:
  • 值 (Any): 需要标准化的值。

lcmp

ding.utils.loader.norm.lcmp(first, *items)[source]
Overview:

比较项目。

Arguments:
  • 第一个 (Any): 第一个项目。

  • items (Any): 其他项目。

loader.number

请参考 ding/utils/loader/number 获取更多详细信息。

数值

ding.utils.loader.number.numeric(int_ok: bool = True, float_ok: bool = True, inf_ok: bool = True) ILoaderClass[source]
Overview:

创建一个数字加载器。

Arguments:
  • int_ok (bool): 是否允许整数。

  • float_ok (bool): 是否允许浮点数。

  • inf_ok (bool): 是否允许inf。

间隔

ding.utils.loader.number.interval(left: int | float | None = None, right: int | float | None = None, left_ok: bool = True, right_ok: bool = True, eps=0.0) ILoaderClass[source]
Overview:

创建一个间隔加载器。

Arguments:
  • left (Optional[NUMBER_TYPING]): 左边界。

  • right (Optional[NUMBER_TYPING]): 右边界。

  • left_ok (bool): 是否允许左边界。

  • right_ok (bool): 是否允许右边界。

  • eps (float): 该epsilon值。

is_negative

ding.utils.loader.number.is_negative() ILoaderClass[source]
Overview:

创建一个负加载器。

is_positive

ding.utils.loader.number.is_positive() ILoaderClass[source]
Overview:

创建一个积极的加载器。

非负数

ding.utils.loader.number.non_negative() ILoaderClass[source]
Overview:

创建一个非负加载器。

非正数

ding.utils.loader.number.non_positive() ILoaderClass[source]
Overview:

创建一个非正加载器。

负数

ding.utils.loader.number.negative() ILoaderClass[source]
Overview:

创建一个负加载器。

正数

ding.utils.loader.number.positive() ILoaderClass[source]
Overview:

创建一个积极的加载器。

_math_binary

ding.utils.loader.number._math_binary(func: Callable[[Any, Any], Any], attachment) ILoaderClass[source]
Overview:

创建一个数学二进制加载器。

Arguments:
  • func (Callable[[Any, Any], Any]): 函数。

  • 附件 (Any): 附件。

plus

ding.utils.loader.number.plus(addend) ILoaderClass[source]
Overview:

创建一个加号加载器。

Arguments:
  • 加数 (Any): 加数。

减去

ding.utils.loader.number.minus(subtrahend) ILoaderClass[source]
Overview:

创建一个减号加载器。

Arguments:
  • 被减数 (Any): 被减数。

minus_with

ding.utils.loader.number.minus_with(minuend) ILoaderClass[source]
Overview:

创建一个减号加载器。

Arguments:
  • 被减数 (Any): 被减数。

ding.utils.loader.number.multi(multiplier) ILoaderClass[source]
Overview:

创建一个多加载器。

Arguments:
  • 乘数 (Any): 乘数。

除法

ding.utils.loader.number.divide(divisor) ILoaderClass[source]
Overview:

创建一个分割加载器。

Arguments:
  • 除数 (Any): 除数。

divide_with

ding.utils.loader.number.divide_with(dividend) ILoaderClass[source]
Overview:

创建一个分割加载器。

Arguments:
  • dividend (Any): 被除数。

power

ding.utils.loader.number.power(index) ILoaderClass[source]
Overview:

创建一个动力装载机。

Arguments:
  • 索引 (Any): 索引。

power_with

ding.utils.loader.number.power_with(base) ILoaderClass[source]
Overview:

创建一个动力装载机。

Arguments:
  • 基础 (Any): 基础。

msum

ding.utils.loader.number.msum(*items) ILoaderClass[source]
Overview:

创建一个求和加载器。

Arguments:
  • 项目 (tuple): 项目。

mmulti

ding.utils.loader.number.mmulti(*items) ILoaderClass[source]
Overview:

创建一个多加载器。

Arguments:
  • 项目 (tuple): 项目。

_msinglecmp

ding.utils.loader.number._msinglecmp(first, op, second) ILoaderClass[source]
Overview:

创建一个单一的比较加载器。

Arguments:
  • 第一个 (Any): 第一个项目。

  • op (str): 操作符。

  • second (Any): 第二个项目。

mcmp

ding.utils.loader.number.mcmp(first, *items) ILoaderClass[source]
Overview:

创建一个多比较加载器。

Arguments:
  • 第一个 (Any): 第一个项目。

  • 项目 (tuple): 项目。

loader.string

请参考 ding/utils/loader/string 获取更多详细信息。

枚举

ding.utils.loader.string.enum(*items, case_sensitive: bool = True) ILoaderClass[source]
Overview:

创建一个枚举加载器。

Arguments:
  • 项目 (Iterable[str]): 项目。

  • case_sensitive (bool): 是否区分大小写。

_to_regexp

ding.utils.loader.string._to_regexp(regexp) Pattern[source]
Overview:

将正则表达式转换为re.Pattern。

Arguments:
  • regexp (Union[str, re.Pattern]): 正则表达式。

rematch

ding.utils.loader.string.rematch(regexp: str | Pattern) ILoaderClass[source]
Overview:

创建一个重新匹配的加载器。

Arguments:
  • regexp (Union[str, re.Pattern]): 正则表达式。

regrep

ding.utils.loader.string.regrep(regexp: str | Pattern, group: int = 0) ILoaderClass[source]
Overview:

创建一个regrep加载器。

Arguments:
  • regexp (Union[str, re.Pattern]): 正则表达式。

  • 组 (int): 组。

loader.types

请参考 ding/utils/loader/types 获取更多详细信息。

is_type

ding.utils.loader.types.is_type(type_: type) ILoaderClass[source]
Overview:

创建一个类型加载器。

Arguments:
  • type_ (type): 类型。

to_type

ding.utils.loader.types.to_type(type_: type) ILoaderClass[source]
Overview:

创建一个类型加载器。

Arguments:
  • type_ (type): 类型。

is_callable

ding.utils.loader.types.is_callable() ILoaderClass[source]
Overview:

创建一个可调用的加载器。

属性

ding.utils.loader.types.prop(attr_name: str) ILoaderClass[source]
Overview:

创建一个属性加载器。

Arguments:
  • attr_name (str): 属性名称。

方法

ding.utils.loader.types.method(method_name: str) ILoaderClass[source]
Overview:

创建一个方法加载器。

Arguments:
  • method_name (str): 方法名称。

fcall

ding.utils.loader.types.fcall(*args, **kwargs) ILoaderClass[source]
Overview:

创建一个函数加载器。

Arguments:
  • args (Tuple[Any]): 参数。

  • kwargs (Dict[str, Any]): 关键字参数。

fpartial

ding.utils.loader.types.fpartial(*args, **kwargs) ILoaderClass[source]
Overview:

创建一个部分函数加载器。

Arguments:
  • args (Tuple[Any]): 参数。

  • kwargs (Dict[str, Any]): 关键字参数。

loader.utils

请参考 ding/utils/loader/utils 获取更多详细信息。

保持

ding.utils.loader.utils.keep() ILoaderClass[source]
Overview:

创建一个保持加载器。

原始

ding.utils.loader.utils.raw(value) ILoaderClass[source]
Overview:

创建一个原始加载器。

可选的

ding.utils.loader.utils.optional(loader) ILoaderClass[source]
Overview:

创建一个可选的加载器。

Arguments:
  • 加载器 (ILoaderClass): 加载器。

仅检查

ding.utils.loader.utils.check_only(loader) ILoaderClass[source]
Overview:

创建一个仅检查的加载器。

Arguments:
  • 加载器 (ILoaderClass): 加载器。

检查

ding.utils.loader.utils.check(loader) ILoaderClass[source]
Overview:

创建一个检查加载器。

Arguments:
  • 加载器 (ILoaderClass): 加载器。