
Interval-based time series classification in aeon

Interval-based approaches look at phase-dependent intervals of the full series, calculating summary statistics from selected subseries to use for classification.
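
As a rough, hand-rolled illustration of this idea (a sketch only, not aeon's implementation), the snippet below extracts the mean, standard deviation and slope from one hand-picked interval of a toy series:

[ ]:
import numpy as np

rng = np.random.default_rng(0)
series = rng.normal(size=100)  # a toy univariate series

start, end = 20, 45  # one phase-dependent interval (hypothetical choice)
interval = series[start:end]

mean = interval.mean()
std = interval.std()
# slope of a least-squares line fitted over the interval
slope = np.polyfit(np.arange(len(interval)), interval, 1)[0]

print(mean, std, slope)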

The interval-based approaches currently implemented in aeon are: Time Series Forest (TSF) [1], Random Interval Spectral Ensemble (RISE) [2], Supervised Time Series Forest (STSF) [3] and Randomised STSF (RSTSF) [4], Canonical Interval Forest (CIF) [5], Diverse Representation Canonical Interval Forest (DrCIF), and QUANT [6]. Most of them are capable of classifying multivariate series.

In this notebook, we will demonstrate how to use these classifiers on the ItalyPowerDemand and BasicMotions datasets.

1. Setup

We can list all of the classifiers in this category like so:

[3]:
import warnings

from sklearn import metrics

from aeon.classification.interval_based import (
    RSTSF,
    CanonicalIntervalForestClassifier,
    DrCIFClassifier,
    QUANTClassifier,
    RandomIntervalSpectralEnsembleClassifier,
    SupervisedTimeSeriesForest,
    TimeSeriesForestClassifier,
)
from aeon.datasets import load_basic_motions, load_italy_power_demand
from aeon.utils.discovery import all_estimators

warnings.filterwarnings("ignore")
all_estimators("classifier", tag_filter={"algorithm_type": "interval"})

2. Load data

[2]:
X_train, y_train = load_italy_power_demand(split="train")
X_test, y_test = load_italy_power_demand(split="test")
X_test = X_test[:50]
y_test = y_test[:50]

print(X_train.shape, y_train.shape, X_test.shape, y_test.shape)

X_train_mv, y_train_mv = load_basic_motions(split="train")
X_test_mv, y_test_mv = load_basic_motions(split="test")

X_train_mv = X_train_mv[:50]
y_train_mv = y_train_mv[:50]
X_test_mv = X_test_mv[:50]
y_test_mv = y_test_mv[:50]

print(X_train_mv.shape, y_train_mv.shape, X_test_mv.shape, y_test_mv.shape)
(67, 1, 24) (67,) (50, 1, 24) (50,)
(40, 6, 100) (40,) (40, 6, 100) (40,)

3. Time Series Forest (TSF)

TSF is an ensemble of tree classifiers built on summary statistics of randomly selected intervals. For each tree, sqrt(n_timepoints) intervals are selected at random. The mean, standard deviation and slope of each time series over each interval are extracted and concatenated into a feature vector. These new features are then used to build a tree, which is added to the ensemble.

[3]:
tsf = TimeSeriesForestClassifier(n_estimators=50, random_state=47)
tsf.fit(X_train, y_train)

tsf_preds = tsf.predict(X_test)
print("TSF Accuracy: " + str(metrics.accuracy_score(y_test, tsf_preds)))
TSF Accuracy: 0.98
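
To make the description above concrete, here is a rough sketch (not the aeon implementation) of the feature vector a single TSF tree is trained on: sqrt(n_timepoints) random intervals, each contributing its mean, standard deviation and slope.

[ ]:
import numpy as np

rng = np.random.default_rng(47)
series = rng.normal(size=24)  # a toy series of ItalyPowerDemand length
n_timepoints = len(series)
n_intervals = int(np.sqrt(n_timepoints))

features = []
for _ in range(n_intervals):
    # pick a random interval of length at least 3
    start = rng.integers(0, n_timepoints - 3)
    end = rng.integers(start + 3, n_timepoints + 1)
    interval = series[start:end]
    slope = np.polyfit(np.arange(end - start), interval, 1)[0]
    features.extend([interval.mean(), interval.std(), slope])

print(len(features), "features for this tree")  # 3 statistics per interval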

4. Random Interval Spectral Ensemble (RISE)

RISE is a tree-based interval ensemble originally designed for classifying audio data. Unlike TSF, it uses a single interval per tree, and it uses spectral features rather than summary statistics.

[4]:
rise = RandomIntervalSpectralEnsembleClassifier(n_estimators=50, random_state=47)
rise.fit(X_train, y_train)

rise_preds = rise.predict(X_test)
print("RISE Accuracy: " + str(metrics.accuracy_score(y_test, rise_preds)))
RISE Accuracy: 1.0
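
As a rough illustration of what "spectral features from a single interval" means (a sketch, not aeon's exact feature set), one tree could be fed the periodogram and autocorrelation values of its interval:

[ ]:
import numpy as np

rng = np.random.default_rng(47)
series = rng.normal(size=100)

start, end = 10, 74  # the single interval used by one tree (hypothetical choice)
interval = series[start:end]

# power spectrum (periodogram) of the interval
power = np.abs(np.fft.rfft(interval)) ** 2

# autocorrelation of the interval for a handful of lags
centred = interval - interval.mean()
acf = np.array(
    [
        np.dot(centred[:-lag], centred[lag:]) / np.dot(centred, centred)
        for lag in range(1, 10)
    ]
)

features = np.concatenate([power, acf])  # feature vector for one tree
print(features.shape)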

5. Supervised Time Series Forest (STSF and RSTSF)

STSF and RSTSF make a number of adjustments to the original TSF algorithm. A supervised method of selecting intervals replaces the random selection. Features are extracted from intervals generated on additional representations of the series: the periodogram and the first-order differences. The extracted summary statistics also include the median, minimum, maximum and interquartile range.

[1]:
stsf = SupervisedTimeSeriesForest(n_estimators=50, random_state=47)
stsf.fit(X_train, y_train)

stsf_preds = stsf.predict(X_test)
print("STSF Accuracy: " + str(metrics.accuracy_score(y_test, stsf_preds)))

rstsf = RSTSF(n_estimators=20)
rstsf.fit(X_train, y_train)

rstsf_preds = rstsf.predict(X_test)
print("RSTSF Accuracy: " + str(metrics.accuracy_score(y_test, rstsf_preds)))
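
A rough sketch (not the aeon implementation) of the extra ingredients STSF adds: intervals can also be drawn from the periodogram and first-order difference representations, and the extracted statistics include the median, minimum, maximum and interquartile range.

[ ]:
import numpy as np

rng = np.random.default_rng(47)
series = rng.normal(size=100)

# the additional representations intervals can be drawn from
representations = {
    "base": series,
    "first differences": np.diff(series),
    "periodogram": np.abs(np.fft.rfft(series)) ** 2,
}

start, end = 5, 30  # one example interval (hypothetical choice)
for name, rep in representations.items():
    interval = rep[start:end]
    stats = [
        np.median(interval),
        interval.min(),
        interval.max(),
        np.percentile(interval, 75) - np.percentile(interval, 25),  # IQR
    ]
    print(name, np.round(stats, 3))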

6. Canonical Interval Forest (CIF)

CIF extends the TSF algorithm. In addition to the 3 summary statistics used by TSF, CIF makes use of features from the Catch22 [7] transform. To increase the diversity of the ensemble, the TSF and Catch22 attributes are randomly subsampled for each tree.
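
The per-tree attribute subsampling can be pictured as follows: a sketch that assumes a candidate pool of the 3 TSF statistics plus the 22 Catch22 features, with each tree receiving its own random subset (this mirrors the description above, not aeon's internals).

[ ]:
import numpy as np

rng = np.random.default_rng(47)

# candidate attributes: 3 TSF summary statistics + 22 Catch22 features
candidates = ["mean", "std", "slope"] + [f"catch22_{i}" for i in range(22)]

att_subsample_size = 8  # as used in the cells below
for tree in range(3):
    # each tree computes only its own random subset of attributes per interval
    subset = rng.choice(candidates, size=att_subsample_size, replace=False)
    print(f"tree {tree}: {sorted(subset)}")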

Univariate

[6]:
cif = CanonicalIntervalForestClassifier(
    n_estimators=50, att_subsample_size=8, random_state=47
)
cif.fit(X_train, y_train)

cif_preds = cif.predict(X_test)
print("CIF Accuracy: " + str(metrics.accuracy_score(y_test, cif_preds)))
CIF Accuracy: 0.98

Multivariate

[7]:
cif_m = CanonicalIntervalForestClassifier(
    n_estimators=50, att_subsample_size=8, random_state=47
)
cif_m.fit(X_train_mv, y_train_mv)

cif_m_preds = cif_m.predict(X_test_mv)
print("CIF Accuracy: " + str(metrics.accuracy_score(y_test_mv, cif_m_preds)))
CIF Accuracy: 1.0

7. Diverse Representation Canonical Interval Forest (DrCIF)

DrCIF makes use of the periodogram and differences representations used by STSF, as well as additional summary statistics, on top of CIF.
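
Putting the pieces together, the per-tree feature count in a DrCIF-style ensemble grows with the number of representations, intervals and subsampled attributes. A back-of-the-envelope sketch with hypothetical interval counts (not aeon's defaults):

[ ]:
n_representations = 3    # base series, first-order differences, periodogram
n_intervals_per_rep = 4  # hypothetical number of intervals per representation
att_subsample_size = 10  # attributes subsampled per tree, as in the cells below

n_features_per_tree = n_representations * n_intervals_per_rep * att_subsample_size
print(n_features_per_tree, "features feed each tree")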

Univariate

[8]:
drcif = DrCIFClassifier(n_estimators=5, att_subsample_size=10, random_state=47)
drcif.fit(X_train, y_train)

drcif_preds = drcif.predict(X_test)
print("DrCIF Accuracy: " + str(metrics.accuracy_score(y_test, drcif_preds)))
DrCIF Accuracy: 0.98

Multivariate

[9]:
drcif_m = DrCIFClassifier(n_estimators=5, att_subsample_size=10, random_state=47)
drcif_m.fit(X_train_mv, y_train_mv)

drcif_m_preds = drcif_m.predict(X_test_mv)
print("DrCIF Accuracy: " + str(metrics.accuracy_score(y_test_mv, drcif_m_preds)))
DrCIF Accuracy: 1.0

8. QUANT

QUANT is a fast interval-based classifier that uses quantile features.

[ ]:
quant = QUANTClassifier(interval_depth=1)
quant.fit(X_train, y_train)
print("QUANT accuracy =", quant.score(X_test, y_test))

Performance on the UCR univariate datasets

You can find the interval-based classifiers as follows.

[1]:
from aeon.utils.discovery import all_estimators

est = all_estimators("classifier", tag_filter={"algorithm_type": "interval"})
for c in est:
    print(c)
('CanonicalIntervalForestClassifier', <class 'aeon.classification.interval_based._cif.CanonicalIntervalForestClassifier'>)
('DrCIFClassifier', <class 'aeon.classification.interval_based._drcif.DrCIFClassifier'>)
('IntervalForestClassifier', <class 'aeon.classification.interval_based._interval_forest.IntervalForestClassifier'>)
('QUANTClassifier', <class 'aeon.classification.interval_based._quant.QUANTClassifier'>)
('RSTSF', <class 'aeon.classification.interval_based._rstsf.RSTSF'>)
('RandomIntervalClassifier', <class 'aeon.classification.interval_based._interval_pipelines.RandomIntervalClassifier'>)
('RandomIntervalSpectralEnsembleClassifier', <class 'aeon.classification.interval_based._rise.RandomIntervalSpectralEnsembleClassifier'>)
('SupervisedIntervalClassifier', <class 'aeon.classification.interval_based._interval_pipelines.SupervisedIntervalClassifier'>)
('SupervisedTimeSeriesForest', <class 'aeon.classification.interval_based._stsf.SupervisedTimeSeriesForest'>)
('TimeSeriesForestClassifier', <class 'aeon.classification.interval_based._tsf.TimeSeriesForestClassifier'>)
[2]:
from aeon.benchmarking.results_loaders import get_estimator_results_as_array
from aeon.datasets.tsc_datasets import univariate

names = [t[0].replace("Classifier", "") for t in est]
names.remove("IntervalForest")  # Base class
names.remove("RandomInterval")  # Pipeline
names.remove("SupervisedInterval")  # Pipeline
results, present_names = get_estimator_results_as_array(
    names, univariate, include_missing=False
)
results.shape
[2]:
(112, 7)
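
Each row of results holds one dataset and each column one classifier, in the order of present_names, so simple aggregates can be read off directly. For example, the mean accuracy of every classifier over the 112 datasets:

[ ]:
import numpy as np

# mean accuracy of each classifier across all datasets
for name, avg in zip(present_names, np.mean(results, axis=0)):
    print(f"{name}: {avg:.3f}")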
[3]:
from aeon.visualisation import plot_boxplot, plot_critical_difference

plot_critical_difference(results, names)
[3]:
(<Figure size 600x260 with 1 Axes>, <Axes: >)
[Image: critical difference diagram comparing the interval-based classifiers]
[4]:
plot_boxplot(results, names, relative=True)
[4]:
(<Figure size 1000x600 with 1 Axes>, <Axes: >)
[Image: relative accuracy boxplot for the interval-based classifiers]

References:

[1] Deng, H. et al. (2013). A time series forest for classification and feature extraction. Information Sciences, 239, 142-153.

[2] Flynn, M. et al. (2019). The Contract Random Interval Spectral Ensemble (c-RISE): the effect of contracting a classifier on accuracy. In International Conference on Hybrid Artificial Intelligence Systems (pp. 381-392).

[3] Cabello, N. et al. (2020). Fast and accurate time series classification through supervised interval search. In IEEE International Conference on Data Mining.

[4] Cabello, N. et al. (2024). Fast, accurate and explainable time series classification through randomization. Data Mining and Knowledge Discovery, 38. https://link.springer.com/article/10.1007/s10618-023-00978-w

[5] Middlehurst, M. et al. (2020). The Canonical Interval Forest (CIF) classifier for time series classification. In IEEE International Conference on Data Mining. https://ieeexplore.ieee.org/document/9378424 arXiv version: https://arxiv.org/abs/2008.09172

[6] Dempster, A. et al. (2024). QUANT: a minimalist interval method for time series classification. Data Mining and Knowledge Discovery, 38. https://link.springer.com/article/10.1007/s10618-024-01036-9

[7] Lubba, C. et al. (2019). catch22: CAnonical Time-series CHaracteristics. Data Mining and Knowledge Discovery, 33(6), 1821-1852.

