aeon 中基于区间的时间序列分类¶
基于区间的方法关注整条序列中与相位相关的区间,并从选定的子序列中计算摘要统计量以用于分类。
aeon 目前实现了以下基于区间的方法:时间序列森林(TSF)[1]、随机区间谱集成(RISE)[2]、监督时间序列森林(STSF)[3] 及其随机化版本 RSTSF [4]、规范区间森林(CIF)[5]、多样化表示规范区间森林(DrCIF)以及 QUANT [6]。其中大多数方法能够对多变量序列进行分类。
在本笔记本中,我们将演示如何在ItalyPowerDemand和BasicMotions数据集上使用这些分类器。
设置¶
我们可以像这样列出此类别中的所有分类器
[3]:
import warnings
from sklearn import metrics
from aeon.classification.interval_based import (
RSTSF,
CanonicalIntervalForestClassifier,
DrCIFClassifier,
QUANTClassifier,
RandomIntervalSpectralEnsembleClassifier,
SupervisedTimeSeriesForest,
TimeSeriesForestClassifier,
)
from aeon.datasets import load_basic_motions, load_italy_power_demand
from aeon.utils.discovery import all_estimators
warnings.filterwarnings("ignore")
all_estimators("classifier", tag_filter={"algorithm_type": "interval"})
---------------------------------------------------------------------------
KeyboardInterrupt Traceback (most recent call last)
Cell In[3], line 18
15 from aeon.utils.discovery import all_estimators
17 warnings.filterwarnings("ignore")
---> 18 all_estimators("classifier", tag_filter={"algorithm_type": "interval"})
File C:\Code\aeon\aeon\utils\discovery.py:121, in all_estimators(type_filter, exclude_types, tag_filter, exclude_tags, include_sklearn, return_names)
116 if (
117 any(part in modules_to_ignore for part in module_parts)
118 or "._" in module_name
119 ):
120 continue
--> 121 module = import_module(module_name)
123 classes = inspect.getmembers(module, inspect.isclass)
124 # skip private estimators and those not implemented in aeon
File ~\AppData\Local\Programs\Python\Python39\lib\importlib\__init__.py:127, in import_module(name, package)
125 break
126 level += 1
--> 127 return _bootstrap._gcd_import(name[level:], package, level)
File <frozen importlib._bootstrap>:1030, in _gcd_import(name, package, level)
File <frozen importlib._bootstrap>:1007, in _find_and_load(name, import_)
File <frozen importlib._bootstrap>:986, in _find_and_load_unlocked(name, import_)
File <frozen importlib._bootstrap>:680, in _load_unlocked(spec)
File <frozen importlib._bootstrap_external>:850, in exec_module(self, module)
File <frozen importlib._bootstrap>:228, in _call_with_frames_removed(f, *args, **kwds)
File C:\Code\aeon\aeon\classification\convolution_based\__init__.py:12
1 """Convolution-based time series classifiers."""
3 __all__ = [
4 "RocketClassifier",
5 "MiniRocketClassifier",
(...)
9 "MultiRocketHydraClassifier",
10 ]
---> 12 from aeon.classification.convolution_based._arsenal import Arsenal
13 from aeon.classification.convolution_based._hydra import HydraClassifier
14 from aeon.classification.convolution_based._minirocket import MiniRocketClassifier
File C:\Code\aeon\aeon\classification\convolution_based\_arsenal.py:20
18 from aeon.base._base import _clone_estimator
19 from aeon.classification.base import BaseClassifier
---> 20 from aeon.transformations.collection.convolution_based import (
21 MiniRocket,
22 MultiRocket,
23 Rocket,
24 )
27 class Arsenal(BaseClassifier):
28 """
29 Arsenal ensemble.
30
(...)
120 >>> y_pred = clf.predict(X_test)
121 """
File C:\Code\aeon\aeon\transformations\collection\convolution_based\__init__.py:13
11 from ._hydra import HydraTransformer
12 from ._minirocket import MiniRocket
---> 13 from ._minirocket_mv import MiniRocketMultivariateVariable
14 from ._multirocket import MultiRocket
15 from ._rocket import Rocket
File C:\Code\aeon\aeon\transformations\collection\convolution_based\_minirocket_mv.py:303
290 return X_2d_t, lengths
293 # code below from the orignal authors: https://github.com/angus924/minirocket
296 @njit(
297 "float32[:](float32[:,:],int32[:],int32[:],int32[:],int32[:],int32[:],float32[:],"
298 "optional(int32))",
299 fastmath=True,
300 parallel=False,
301 cache=True,
302 )
--> 303 def _fit_biases_multi_var(
304 X,
305 L,
306 num_channels_per_combination,
307 channel_indices,
308 dilations,
309 num_features_per_dilation,
310 quantiles,
311 seed,
312 ):
313 if seed is not None:
314 np.random.seed(seed)
File C:\Code\aeon\venv\lib\site-packages\numba\core\decorators.py:232, in _jit.<locals>.wrapper(func)
230 with typeinfer.register_dispatcher(disp):
231 for sig in sigs:
--> 232 disp.compile(sig)
233 disp.disable_compile()
234 return disp
File C:\Code\aeon\venv\lib\site-packages\numba\core\dispatcher.py:905, in Dispatcher.compile(self, sig)
903 with ev.trigger_event("numba:compile", data=ev_details):
904 try:
--> 905 cres = self._compiler.compile(args, return_type)
906 except errors.ForceLiteralArg as e:
907 def folded(args, kws):
File C:\Code\aeon\venv\lib\site-packages\numba\core\dispatcher.py:80, in _FunctionCompiler.compile(self, args, return_type)
79 def compile(self, args, return_type):
---> 80 status, retval = self._compile_cached(args, return_type)
81 if status:
82 return retval
File C:\Code\aeon\venv\lib\site-packages\numba\core\dispatcher.py:94, in _FunctionCompiler._compile_cached(self, args, return_type)
91 pass
93 try:
---> 94 retval = self._compile_core(args, return_type)
95 except errors.TypingError as e:
96 self._failed_cache[key] = e
File C:\Code\aeon\venv\lib\site-packages\numba\core\dispatcher.py:107, in _FunctionCompiler._compile_core(self, args, return_type)
104 flags = self._customize_flags(flags)
106 impl = self._get_implementation(args, {})
--> 107 cres = compiler.compile_extra(self.targetdescr.typing_context,
108 self.targetdescr.target_context,
109 impl,
110 args=args, return_type=return_type,
111 flags=flags, locals=self.locals,
112 pipeline_class=self.pipeline_class)
113 # Check typing error if object mode is used
114 if cres.typing_error is not None and not flags.enable_pyobject:
File C:\Code\aeon\venv\lib\site-packages\numba\core\compiler.py:744, in compile_extra(typingctx, targetctx, func, args, return_type, flags, locals, library, pipeline_class)
720 """Compiler entry point
721
722 Parameter
(...)
740 compiler pipeline
741 """
742 pipeline = pipeline_class(typingctx, targetctx, library,
743 args, return_type, flags, locals)
--> 744 return pipeline.compile_extra(func)
File C:\Code\aeon\venv\lib\site-packages\numba\core\compiler.py:438, in CompilerBase.compile_extra(self, func)
436 self.state.lifted = ()
437 self.state.lifted_from = None
--> 438 return self._compile_bytecode()
File C:\Code\aeon\venv\lib\site-packages\numba\core\compiler.py:506, in CompilerBase._compile_bytecode(self)
502 """
503 Populate and run pipeline for bytecode input
504 """
505 assert self.state.func_ir is None
--> 506 return self._compile_core()
File C:\Code\aeon\venv\lib\site-packages\numba\core\compiler.py:472, in CompilerBase._compile_core(self)
470 res = None
471 try:
--> 472 pm.run(self.state)
473 if self.state.cr is not None:
474 break
File C:\Code\aeon\venv\lib\site-packages\numba\core\compiler_machinery.py:356, in PassManager.run(self, state)
354 pass_inst = _pass_registry.get(pss).pass_inst
355 if isinstance(pass_inst, CompilerPass):
--> 356 self._runPass(idx, pass_inst, state)
357 else:
358 raise BaseException("Legacy pass in use")
File C:\Code\aeon\venv\lib\site-packages\numba\core\compiler_lock.py:35, in _CompilerLock.__call__.<locals>._acquire_compile_lock(*args, **kwargs)
32 @functools.wraps(func)
33 def _acquire_compile_lock(*args, **kwargs):
34 with self:
---> 35 return func(*args, **kwargs)
File C:\Code\aeon\venv\lib\site-packages\numba\core\compiler_machinery.py:311, in PassManager._runPass(self, index, pss, internal_state)
309 mutated |= check(pss.run_initialization, internal_state)
310 with SimpleTimer() as pass_time:
--> 311 mutated |= check(pss.run_pass, internal_state)
312 with SimpleTimer() as finalize_time:
313 mutated |= check(pss.run_finalizer, internal_state)
File C:\Code\aeon\venv\lib\site-packages\numba\core\compiler_machinery.py:273, in PassManager._runPass.<locals>.check(func, compiler_state)
272 def check(func, compiler_state):
--> 273 mangled = func(compiler_state)
274 if mangled not in (True, False):
275 msg = ("CompilerPass implementations should return True/False. "
276 "CompilerPass with name '%s' did not.")
File C:\Code\aeon\venv\lib\site-packages\numba\core\typed_passes.py:112, in BaseTypeInference.run_pass(self, state)
106 """
107 Type inference and legalization
108 """
109 with fallback_context(state, 'Function "%s" failed type inference'
110 % (state.func_id.func_name,)):
111 # Type inference
--> 112 typemap, return_type, calltypes, errs = type_inference_stage(
113 state.typingctx,
114 state.targetctx,
115 state.func_ir,
116 state.args,
117 state.return_type,
118 state.locals,
119 raise_errors=self._raise_errors)
120 state.typemap = typemap
121 # save errors in case of partial typing
File C:\Code\aeon\venv\lib\site-packages\numba\core\typed_passes.py:93, in type_inference_stage(typingctx, targetctx, interp, args, return_type, locals, raise_errors)
91 infer.build_constraint()
92 # return errors in case of partial typing
---> 93 errs = infer.propagate(raise_errors=raise_errors)
94 typemap, restype, calltypes = infer.unify(raise_errors=raise_errors)
96 return _TypingResults(typemap, restype, calltypes, errs)
File C:\Code\aeon\venv\lib\site-packages\numba\core\typeinfer.py:1083, in TypeInferer.propagate(self, raise_errors)
1080 oldtoken = newtoken
1081 # Errors can appear when the type set is incomplete; only
1082 # raise them when there is no progress anymore.
-> 1083 errors = self.constraints.propagate(self)
1084 newtoken = self.get_state_token()
1085 self.debug.propagate_finished()
File C:\Code\aeon\venv\lib\site-packages\numba\core\typeinfer.py:160, in ConstraintNetwork.propagate(self, typeinfer)
157 with typeinfer.warnings.catch_warnings(filename=loc.filename,
158 lineno=loc.line):
159 try:
--> 160 constraint(typeinfer)
161 except ForceLiteralArg as e:
162 errors.append(e)
File C:\Code\aeon\venv\lib\site-packages\numba\core\typeinfer.py:583, in CallConstraint.__call__(self, typeinfer)
581 fnty = typevars[self.func].getone()
582 with new_error_context("resolving callee type: {0}", fnty):
--> 583 self.resolve(typeinfer, typevars, fnty)
File C:\Code\aeon\venv\lib\site-packages\numba\core\typeinfer.py:606, in CallConstraint.resolve(self, typeinfer, typevars, fnty)
604 fnty = fnty.instance_type
605 try:
--> 606 sig = typeinfer.resolve_call(fnty, pos_args, kw_args)
607 except ForceLiteralArg as e:
608 # Adjust for bound methods
609 folding_args = ((fnty.this,) + tuple(self.args)
610 if isinstance(fnty, types.BoundFunction)
611 else self.args)
File C:\Code\aeon\venv\lib\site-packages\numba\core\typeinfer.py:1577, in TypeInferer.resolve_call(self, fnty, pos_args, kw_args)
1574 return sig
1575 else:
1576 # Normal non-recursive call
-> 1577 return self.context.resolve_function_type(fnty, pos_args, kw_args)
File C:\Code\aeon\venv\lib\site-packages\numba\core\typing\context.py:196, in BaseContext.resolve_function_type(self, func, args, kws)
194 # Prefer user definition first
195 try:
--> 196 res = self._resolve_user_function_type(func, args, kws)
197 except errors.TypingError as e:
198 # Capture any typing error
199 last_exception = e
File C:\Code\aeon\venv\lib\site-packages\numba\core\typing\context.py:248, in BaseContext._resolve_user_function_type(self, func, args, kws, literals)
244 return self.resolve_function_type(func_type, args, kws)
246 if isinstance(func, types.Callable):
247 # XXX fold this into the __call__ attribute logic?
--> 248 return func.get_call_type(self, args, kws)
File C:\Code\aeon\venv\lib\site-packages\numba\core\types\functions.py:308, in BaseFunction.get_call_type(self, context, args, kws)
305 nolitargs = tuple([_unlit_non_poison(a) for a in args])
306 nolitkws = {k: _unlit_non_poison(v)
307 for k, v in kws.items()}
--> 308 sig = temp.apply(nolitargs, nolitkws)
309 except Exception as e:
310 if (utils.use_new_style_errors() and not
311 isinstance(e, errors.NumbaError)):
File C:\Code\aeon\venv\lib\site-packages\numba\core\typing\templates.py:350, in AbstractTemplate.apply(self, args, kws)
348 def apply(self, args, kws):
349 generic = getattr(self, "generic")
--> 350 sig = generic(args, kws)
351 # Enforce that *generic()* must return None or Signature
352 if sig is not None:
File C:\Code\aeon\venv\lib\site-packages\numba\core\typing\templates.py:613, in _OverloadFunctionTemplate.generic(self, args, kws)
607 """
608 Type the overloaded function by compiling the appropriate
609 implementation for the given args.
610 """
611 from numba.core.typed_passes import PreLowerStripPhis
--> 613 disp, new_args = self._get_impl(args, kws)
614 if disp is None:
615 return
File C:\Code\aeon\venv\lib\site-packages\numba\core\typing\templates.py:712, in _OverloadFunctionTemplate._get_impl(self, args, kws)
708 except KeyError:
709 # pass and try outside the scope so as to not have KeyError with a
710 # nested addition error in the case the _build_impl fails
711 pass
--> 712 impl, args = self._build_impl(cache_key, args, kws)
713 return impl, args
File C:\Code\aeon\venv\lib\site-packages\numba\core\typing\templates.py:816, in _OverloadFunctionTemplate._build_impl(self, cache_key, args, kws)
814 # Make sure that the implementation can be fully compiled
815 disp_type = types.Dispatcher(disp)
--> 816 disp_type.get_call_type(self.context, args, kws)
817 if cache_key is not None:
818 self._impl_cache[cache_key] = disp, args
File C:\Code\aeon\venv\lib\site-packages\numba\core\types\functions.py:541, in Dispatcher.get_call_type(self, context, args, kws)
534 def get_call_type(self, context, args, kws):
535 """
536 Resolve a call to this dispatcher using the given argument types.
537 A signature returned and it is ensured that a compiled specialization
538 is available for it.
539 """
540 template, pysig, args, kws = \
--> 541 self.dispatcher.get_call_template(args, kws)
542 sig = template(context).apply(args, kws)
543 if sig:
File C:\Code\aeon\venv\lib\site-packages\numba\core\dispatcher.py:318, in _DispatcherBase.get_call_template(self, args, kws)
316 # Ensure an overload is available
317 if self._can_compile:
--> 318 self.compile(tuple(args))
320 # Create function type for typing
321 func_name = self.py_func.__name__
File C:\Code\aeon\venv\lib\site-packages\numba\core\dispatcher.py:905, in Dispatcher.compile(self, sig)
903 with ev.trigger_event("numba:compile", data=ev_details):
904 try:
--> 905 cres = self._compiler.compile(args, return_type)
906 except errors.ForceLiteralArg as e:
907 def folded(args, kws):
File C:\Code\aeon\venv\lib\site-packages\numba\core\dispatcher.py:80, in _FunctionCompiler.compile(self, args, return_type)
79 def compile(self, args, return_type):
---> 80 status, retval = self._compile_cached(args, return_type)
81 if status:
82 return retval
File C:\Code\aeon\venv\lib\site-packages\numba\core\dispatcher.py:94, in _FunctionCompiler._compile_cached(self, args, return_type)
91 pass
93 try:
---> 94 retval = self._compile_core(args, return_type)
95 except errors.TypingError as e:
96 self._failed_cache[key] = e
File C:\Code\aeon\venv\lib\site-packages\numba\core\dispatcher.py:107, in _FunctionCompiler._compile_core(self, args, return_type)
104 flags = self._customize_flags(flags)
106 impl = self._get_implementation(args, {})
--> 107 cres = compiler.compile_extra(self.targetdescr.typing_context,
108 self.targetdescr.target_context,
109 impl,
110 args=args, return_type=return_type,
111 flags=flags, locals=self.locals,
112 pipeline_class=self.pipeline_class)
113 # Check typing error if object mode is used
114 if cres.typing_error is not None and not flags.enable_pyobject:
File C:\Code\aeon\venv\lib\site-packages\numba\core\compiler.py:744, in compile_extra(typingctx, targetctx, func, args, return_type, flags, locals, library, pipeline_class)
720 """Compiler entry point
721
722 Parameter
(...)
740 compiler pipeline
741 """
742 pipeline = pipeline_class(typingctx, targetctx, library,
743 args, return_type, flags, locals)
--> 744 return pipeline.compile_extra(func)
File C:\Code\aeon\venv\lib\site-packages\numba\core\compiler.py:438, in CompilerBase.compile_extra(self, func)
436 self.state.lifted = ()
437 self.state.lifted_from = None
--> 438 return self._compile_bytecode()
File C:\Code\aeon\venv\lib\site-packages\numba\core\compiler.py:506, in CompilerBase._compile_bytecode(self)
502 """
503 Populate and run pipeline for bytecode input
504 """
505 assert self.state.func_ir is None
--> 506 return self._compile_core()
File C:\Code\aeon\venv\lib\site-packages\numba\core\compiler.py:472, in CompilerBase._compile_core(self)
470 res = None
471 try:
--> 472 pm.run(self.state)
473 if self.state.cr is not None:
474 break
File C:\Code\aeon\venv\lib\site-packages\numba\core\compiler_machinery.py:356, in PassManager.run(self, state)
354 pass_inst = _pass_registry.get(pss).pass_inst
355 if isinstance(pass_inst, CompilerPass):
--> 356 self._runPass(idx, pass_inst, state)
357 else:
358 raise BaseException("Legacy pass in use")
File C:\Code\aeon\venv\lib\site-packages\numba\core\compiler_lock.py:35, in _CompilerLock.__call__.<locals>._acquire_compile_lock(*args, **kwargs)
32 @functools.wraps(func)
33 def _acquire_compile_lock(*args, **kwargs):
34 with self:
---> 35 return func(*args, **kwargs)
File C:\Code\aeon\venv\lib\site-packages\numba\core\compiler_machinery.py:311, in PassManager._runPass(self, index, pss, internal_state)
309 mutated |= check(pss.run_initialization, internal_state)
310 with SimpleTimer() as pass_time:
--> 311 mutated |= check(pss.run_pass, internal_state)
312 with SimpleTimer() as finalize_time:
313 mutated |= check(pss.run_finalizer, internal_state)
File C:\Code\aeon\venv\lib\site-packages\numba\core\compiler_machinery.py:273, in PassManager._runPass.<locals>.check(func, compiler_state)
272 def check(func, compiler_state):
--> 273 mangled = func(compiler_state)
274 if mangled not in (True, False):
275 msg = ("CompilerPass implementations should return True/False. "
276 "CompilerPass with name '%s' did not.")
File C:\Code\aeon\venv\lib\site-packages\numba\core\typed_passes.py:497, in BaseNativeLowering.run_pass(self, state)
491 state['cr'] = _LowerResult(fndesc, call_helper,
492 cfunc=None, env=env)
493 else:
494 # Prepare for execution
495 # Insert native function for use by other jitted-functions.
496 # We also register its library to allow for inlining.
--> 497 cfunc = targetctx.get_executable(library, fndesc, env)
498 targetctx.insert_user_function(cfunc, fndesc, [library])
499 state['cr'] = _LowerResult(fndesc, call_helper,
500 cfunc=cfunc, env=env)
File C:\Code\aeon\venv\lib\site-packages\numba\core\cpu.py:239, in CPUContext.get_executable(self, library, fndesc, env)
226 """
227 Returns
228 -------
(...)
236 an execution environment (from _dynfunc)
237 """
238 # Code generation
--> 239 fnptr = library.get_pointer_to_function(
240 fndesc.llvm_cpython_wrapper_name)
242 # Note: we avoid reusing the original docstring to avoid encoding
243 # issues on Python 2, see issue #1908
244 doc = "compiled wrapper for %r" % (fndesc.qualname,)
File C:\Code\aeon\venv\lib\site-packages\numba\core\codegen.py:989, in JITCodeLibrary.get_pointer_to_function(self, name)
975 def get_pointer_to_function(self, name):
976 """
977 Generate native code for function named *name* and return a pointer
978 to the start of the function (as an integer).
(...)
987 - non-zero if the symbol is defined.
988 """
--> 989 self._ensure_finalized()
990 ee = self._codegen._engine
991 if not ee.is_symbol_defined(name):
File C:\Code\aeon\venv\lib\site-packages\numba\core\codegen.py:567, in CodeLibrary._ensure_finalized(self)
565 def _ensure_finalized(self):
566 if not self._finalized:
--> 567 self.finalize()
File C:\Code\aeon\venv\lib\site-packages\numba\core\codegen.py:762, in CPUCodeLibrary.finalize(self)
756 self._final_module.link_in(
757 library._get_module_for_linking(), preserve=True,
758 )
760 # Optimize the module after all dependences are linked in above,
761 # to allow for inlining.
--> 762 self._optimize_final_module()
764 self._final_module.verify()
765 self._finalize_final_module()
File C:\Code\aeon\venv\lib\site-packages\numba\core\codegen.py:682, in CPUCodeLibrary._optimize_final_module(self)
679 full_name = "Module passes (full optimization)"
680 with self._recorded_timings.record(full_name):
681 # The full optimisation suite is then run on the refop pruned IR
--> 682 self._codegen._mpm_full.run(self._final_module)
File C:\Code\aeon\venv\lib\site-packages\llvmlite\binding\passmanagers.py:698, in ModulePassManager.run(self, module, remarks_file, remarks_format, remarks_filter)
683 """
684 Run optimization passes on the given module.
685
(...)
695 The filter that should be applied to the remarks output.
696 """
697 if remarks_file is None:
--> 698 return ffi.lib.LLVMPY_RunPassManager(self, module)
699 else:
700 r = ffi.lib.LLVMPY_RunPassManagerWithRemarks(
701 self, module, _encode_string(remarks_format),
702 _encode_string(remarks_filter),
703 _encode_string(remarks_file))
File C:\Code\aeon\venv\lib\site-packages\llvmlite\binding\ffi.py:192, in _lib_fn_wrapper.__call__(self, *args, **kwargs)
190 def __call__(self, *args, **kwargs):
191 with self._lock:
--> 192 return self._cfn(*args, **kwargs)
KeyboardInterrupt:
2. 加载数据¶
[2]:
X_train, y_train = load_italy_power_demand(split="train")
X_test, y_test = load_italy_power_demand(split="test")
X_test = X_test[:50]
y_test = y_test[:50]
print(X_train.shape, y_train.shape, X_test.shape, y_test.shape)
X_train_mv, y_train_mv = load_basic_motions(split="train")
X_test_mv, y_test_mv = load_basic_motions(split="test")
X_train_mv = X_train_mv[:50]
y_train_mv = y_train_mv[:50]
X_test_mv = X_test_mv[:50]
y_test_mv = y_test_mv[:50]
print(X_train_mv.shape, y_train_mv.shape, X_test_mv.shape, y_test_mv.shape)
(67, 1) (67,) (50, 1) (50,)
(40, 6) (40,) (40, 6) (40,)
3. 时间序列森林 (TSF)¶
TSF 是一组基于随机选择区间的汇总统计构建的树分类器。对于每棵树,随机选择 sqrt(n_timepoints) 个区间。从每个区间中提取每个时间序列的均值、标准差和斜率,并将它们连接成一个特征向量。然后使用这些新特征构建一棵树,并将其添加到集成中。
[3]:
tsf = TimeSeriesForestClassifier(n_estimators=50, random_state=47)
tsf.fit(X_train, y_train)
tsf_preds = tsf.predict(X_test)
print("TSF Accuracy: " + str(metrics.accuracy_score(y_test, tsf_preds)))
TSF Accuracy: 0.98
4. 随机区间谱集成 (RISE)¶
RISE 是一种基于树的区间集成方法,旨在对音频数据进行分类。与 TSF 不同,它为每棵树使用单个区间,并且使用频谱特征而不是汇总统计。
[4]:
rise = RandomIntervalSpectralEnsembleClassifier(n_estimators=50, random_state=47)
rise.fit(X_train, y_train)
rise_preds = rise.predict(X_test)
print("RISE Accuracy: " + str(metrics.accuracy_score(y_test, rise_preds)))
RISE Accuracy: 1.0
5. 监督时间序列森林(STSF 和 RSTSF)¶
STSF/RSTSF 对原始的 TSF 算法进行了多项调整。选择区间的监督方法取代了随机选择。特征是从周期图和一阶差分中的附加表示生成的区间中提取的。提取的汇总统计量包括中位数、最小值、最大值和四分位距。
[1]:
stsf = SupervisedTimeSeriesForest(n_estimators=50, random_state=47)
stsf.fit(X_train, y_train)
stsf_preds = stsf.predict(X_test)
print("STSF Accuracy: " + str(metrics.accuracy_score(y_test, stsf_preds)))
rstsf = RSTSF(n_estimators=20)
rstsf.fit(X_train, y_train)
rstsf_preds = rstsf.predict(X_test)
print("RSTSF Accuracy: " + str(metrics.accuracy_score(y_test, rstsf_preds)))
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
Cell In[1], line 1
----> 1 stsf = SupervisedTimeSeriesForest(n_estimators=50, random_state=47)
2 stsf.fit(X_train, y_train)
4 stsf_preds = stsf.predict(X_test)
NameError: name 'SupervisedTimeSeriesForest' is not defined
6. 规范区间森林 (CIF)¶
CIF 是从 TSF 算法扩展而来的。除了 TSF 使用的 3 个摘要统计量外,CIF 还利用了 Catch22 [7] 转换的特征。为了增加集成的多样性,每棵树中 TSF 和 catch22 属性的数量是随机子采样的。
单变量¶
[6]:
cif = CanonicalIntervalForestClassifier(
n_estimators=50, att_subsample_size=8, random_state=47
)
cif.fit(X_train, y_train)
cif_preds = cif.predict(X_test)
print("CIF Accuracy: " + str(metrics.accuracy_score(y_test, cif_preds)))
CIF Accuracy: 0.98
多变量¶
[7]:
cif_m = CanonicalIntervalForestClassifier(
n_estimators=50, att_subsample_size=8, random_state=47
)
cif_m.fit(X_train_mv, y_train_mv)
cif_m_preds = cif_m.predict(X_test_mv)
print("CIF Accuracy: " + str(metrics.accuracy_score(y_test_mv, cif_m_preds)))
CIF Accuracy: 1.0
6. 多样表示规范区间森林 (DrCIF)¶
DrCIF 同时利用了 STSF 所使用的周期图和一阶差分表示,以及 CIF 中新增的摘要统计量。
单变量¶
[8]:
drcif = DrCIFClassifier(n_estimators=5, att_subsample_size=10, random_state=47)
drcif.fit(X_train, y_train)
drcif_preds = drcif.predict(X_test)
print("DrCIF Accuracy: " + str(metrics.accuracy_score(y_test, drcif_preds)))
DrCIF Accuracy: 0.98
多变量¶
[9]:
drcif_m = DrCIFClassifier(n_estimators=5, att_subsample_size=10, random_state=47)
drcif_m.fit(X_train_mv, y_train_mv)
drcif_m_preds = drcif_m.predict(X_test_mv)
print("DrCIF Accuracy: " + str(metrics.accuracy_score(y_test_mv, drcif_m_preds)))
DrCIF Accuracy: 1.0
7. QUANT¶
QUANT 是一个基于分位数特征的快速区间分类器
[ ]:
quant = QUANTClassifier(interval_depth=1)
quant.fit(X_train, y_train)
print("QUANT accuracy =", quant.score(X_test, y_test))
在UCR单变量数据集上的表现¶
您可以按如下方式找到基于区间的分类器。
[1]:
from aeon.utils.discovery import all_estimators
est = all_estimators("classifier", tag_filter={"algorithm_type": "interval"})
for c in est:
print(c)
('CanonicalIntervalForestClassifier', <class 'aeon.classification.interval_based._cif.CanonicalIntervalForestClassifier'>)
('DrCIFClassifier', <class 'aeon.classification.interval_based._drcif.DrCIFClassifier'>)
('IntervalForestClassifier', <class 'aeon.classification.interval_based._interval_forest.IntervalForestClassifier'>)
('QUANTClassifier', <class 'aeon.classification.interval_based._quant.QUANTClassifier'>)
('RSTSF', <class 'aeon.classification.interval_based._rstsf.RSTSF'>)
('RandomIntervalClassifier', <class 'aeon.classification.interval_based._interval_pipelines.RandomIntervalClassifier'>)
('RandomIntervalSpectralEnsembleClassifier', <class 'aeon.classification.interval_based._rise.RandomIntervalSpectralEnsembleClassifier'>)
('SupervisedIntervalClassifier', <class 'aeon.classification.interval_based._interval_pipelines.SupervisedIntervalClassifier'>)
('SupervisedTimeSeriesForest', <class 'aeon.classification.interval_based._stsf.SupervisedTimeSeriesForest'>)
('TimeSeriesForestClassifier', <class 'aeon.classification.interval_based._tsf.TimeSeriesForestClassifier'>)
[2]:
from aeon.benchmarking.results_loaders import get_estimator_results_as_array
from aeon.datasets.tsc_datasets import univariate
names = [t[0].replace("Classifier", "") for t in est]
names.remove("IntervalForest") # Base class
names.remove("RandomInterval") # Pipeline
names.remove("SupervisedInterval") # Pipeline
results, present_names = get_estimator_results_as_array(
names, univariate, include_missing=False
)
results.shape
[2]:
(112, 7)
[3]:
from aeon.visualisation import plot_boxplot, plot_critical_difference
plot_critical_difference(results, names)
[3]:
(<Figure size 600x260 with 1 Axes>, <Axes: >)
[4]:
plot_boxplot(results, names, relative=True)
[4]:
(<Figure size 1000x600 with 1 Axes>, <Axes: >)
参考文献:¶
[1] Deng, H. 等人. (2013). 用于分类和特征提取的时间序列森林. 信息科学, 239, 142-153.
[2] Flynn, M. 等人. (2019). 限时随机区间谱集成 (c-RISE): 对分类器进行限时训练对准确率的影响. 发表于混合人工智能系统国际会议 (pp. 381-392).
[3] Cabello, N. 等人. (2020). 通过监督区间搜索实现快速准确的时间序列分类. 发表于 IEEE 国际数据挖掘会议.
[4] Cabello, N. 等人. (2024). 通过随机化实现快速、准确且可解释的时间序列分类. 数据挖掘与知识发现 38: https://link.springer.com/article/10.1007/s10618-023-00978-w
[5] Middlehurst, M. 等人. (2020). 用于时间序列分类的规范区间森林(CIF)分类器. IEEE 国际数据挖掘会议. https://ieeexplore.ieee.org/document/9378424 arXiv 版本: https://arxiv.org/abs/2008.09172
[6] Dempster, A. (2024). QUANT: 一种用于时间序列分类的极简区间方法. 数据挖掘与知识发现 38: https://link.springer.com/article/10.1007/s10618-024-01036-9
[7] Lubba, C. 等人. (2019). catch22: 规范时间序列特征. 数据挖掘与知识发现, 33(6), 1821-1852.