lib.sedna.core.incremental_learning.incremental_learning
Module Contents
Classes
- class lib.sedna.core.incremental_learning.incremental_learning.IncrementalLearning(estimator, hard_example_mining: dict = None)
Bases: sedna.core.base.JobBase
Incremental learning is a machine learning method in which input data is continuously used to extend the existing model’s knowledge, i.e., to further train the model. It is a dynamic technique, applicable to both supervised and unsupervised learning, for use when training data becomes available gradually over time.
Sedna provides the interfaces needed to develop incremental learning applications.
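To make the idea concrete, here is a minimal, framework-independent sketch of incremental learning, assuming a nearest-centroid classifier over 2-D features. It is not part of Sedna's API: IncrementalCentroidClassifier and partial_fit are hypothetical names chosen for illustration. Each new batch is folded into running per-class sums, so the model extends its knowledge without revisiting earlier data.

```python
from collections import defaultdict


class IncrementalCentroidClassifier:
    """Hypothetical nearest-centroid classifier whose class centroids are
    updated in place as new labelled batches arrive."""

    def __init__(self):
        self.sums = defaultdict(lambda: [0.0, 0.0])  # per-class feature sums (2-D features)
        self.counts = defaultdict(int)               # per-class sample counts

    def partial_fit(self, xs, ys):
        # Fold a new batch into the running sums; earlier batches are not needed.
        for (x0, x1), y in zip(xs, ys):
            s = self.sums[y]
            s[0] += x0
            s[1] += x1
            self.counts[y] += 1
        return self

    def centroid(self, label):
        s, n = self.sums[label], self.counts[label]
        return (s[0] / n, s[1] / n)

    def predict(self, x):
        # Assign x to the seen class with the closest centroid (squared Euclidean distance).
        def dist(label):
            cx, cy = self.centroid(label)
            return (x[0] - cx) ** 2 + (x[1] - cy) ** 2
        return min(self.counts, key=dist)


clf = IncrementalCentroidClassifier()
clf.partial_fit([(0.0, 0.0), (1.0, 1.0)], ["a", "b"])  # first batch
clf.partial_fit([(0.2, 0.0), (0.8, 1.2)], ["a", "b"])  # a later batch extends the model
print(clf.centroid("a"))        # (0.1, 0.0)
print(clf.predict((0.9, 0.9)))  # b
```

The running-sum design keeps memory proportional to the number of classes rather than the number of samples seen, which is what lets training continue as data trickles in.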
- Parameters:
estimator (Instance) – An estimator instance exposing a high-level API that simplifies machine learning programming. Estimators encapsulate training, evaluation, prediction, and exporting of your model.
hard_example_mining (Dict) – The hard example mining (HEM) algorithm and its parameters; the algorithm must be registered in ClassFactory. See sedna.algorithms.hard_example_mining for more details.
Examples
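>>> from tensorflow import keras
>>> from sedna.core.incremental_learning import IncrementalLearning
>>> Estimator = keras.models.Sequential()
>>> il_model = IncrementalLearning(
...     estimator=Estimator,
...     hard_example_mining={
...         "method": "IBT",
...         "param": {
...             "threshold_img": 0.9
...         }
...     }
... )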
Notes
Sedna provides an interface, get_hem_algorithm_from_config, that builds the hard_example_mining parameter from the CRD definition.
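The hard_example_mining dict names an algorithm and its parameters. The sketch below illustrates how such a {"method": ..., "param": {...}} dict can be resolved through a name-to-class registry. HEM_REGISTRY, register_hem, build_hem and the ConfidenceThreshold filter are hypothetical stand-ins; they are not Sedna's ClassFactory or its IBT algorithm.

```python
# Hypothetical registry standing in for Sedna's ClassFactory; not Sedna's real API.
HEM_REGISTRY = {}


def register_hem(name):
    """Decorator that records a hard example mining (HEM) filter class under `name`."""
    def wrap(cls):
        HEM_REGISTRY[name] = cls
        return cls
    return wrap


@register_hem("ConfidenceThreshold")
class ConfidenceThresholdFilter:
    """Hypothetical filter: a sample is 'hard' when the model's top
    confidence score falls below `threshold`."""

    def __init__(self, threshold=0.5):
        self.threshold = threshold

    def __call__(self, scores):
        return max(scores) < self.threshold


def build_hem(config):
    """Instantiate the filter described by {"method": ..., "param": {...}}."""
    if not config:
        return None  # hard example mining disabled
    cls = HEM_REGISTRY[config["method"]]
    return cls(**config.get("param", {}))


hem = build_hem({"method": "ConfidenceThreshold", "param": {"threshold": 0.9}})
print(hem([0.7, 0.2, 0.1]))  # True: top score 0.7 is below 0.9
print(hem([0.95, 0.05]))     # False
```

Keeping the method name as a string in the config is what allows it to be supplied declaratively, for example from a CRD, and resolved to code at runtime.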
- classmethod get_hem_algorithm_from_config(**param)
Get the algorithm name and parameters of hard_example_mining from the CRD.
- Parameters:
param (Dict) – values that update the parameters of hard_example_mining.
- Returns:
e.g., {"method": "IBT", "param": {"threshold_img": 0.5}}
- Return type:
dict
Examples
>>> IncrementalLearning.get_hem_algorithm_from_config(
...     threshold_img=0.9
... )
{"method": "IBT", "param": {"threshold_img": 0.9}}
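A minimal sketch of what such a helper might do, assuming the CRD values reach the worker as environment variables named HEM_NAME and HEM_PARAMETERS (JSON-encoded). Those variable names and the extra env argument are assumptions for illustration, not Sedna's documented behavior; the sketch only shows the merge order, in which keyword arguments override the CRD values.

```python
import json
import os


def get_hem_algorithm_from_config(env=None, **param):
    """Build {"method": ..., "param": {...}} from CRD-derived settings.

    Assumption: the CRD values arrive as environment variables HEM_NAME and
    HEM_PARAMETERS (a JSON object); both names are hypothetical.
    """
    env = os.environ if env is None else env
    method = env.get("HEM_NAME")
    params = json.loads(env.get("HEM_PARAMETERS", "{}"))
    params.update(param)  # keyword arguments override the CRD values
    return {"method": method, "param": params}


crd_env = {"HEM_NAME": "IBT", "HEM_PARAMETERS": '{"threshold_img": 0.5}'}
print(get_hem_algorithm_from_config(env=crd_env))
# {'method': 'IBT', 'param': {'threshold_img': 0.5}}
print(get_hem_algorithm_from_config(env=crd_env, threshold_img=0.9))
# {'method': 'IBT', 'param': {'threshold_img': 0.9}}
```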
- train(train_data, valid_data=None, post_process=None, **kwargs)
Training task for IncrementalLearning
- Parameters:
train_data (BaseDataSource) – data source used for training; see sedna.datasources.BaseDataSource for more details.
valid_data (BaseDataSource) – data source used for evaluation; see sedna.datasources.BaseDataSource for more details.
post_process (function or a registered method) – applied after estimator training.
kwargs (Dict) – parameters for estimator training, e.g., early_stopping_rounds in xgboost.XGBClassifier.
- Return type:
estimator
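The parameter list above allows post_process to be either a function or a registered method. The following sketch shows one way to support both and to forward kwargs to the estimator; CALLBACKS, resolve_post_process, ToyEstimator and this train function are hypothetical illustrations, not Sedna's implementation.

```python
from typing import Any, Callable, Dict, Optional, Union

# Hypothetical registry of named post-processing callbacks, standing in for a
# "registered method"; this is not Sedna's real API.
CALLBACKS: Dict[str, Callable[[Any], Any]] = {}


def resolve_post_process(post_process: Union[None, str, Callable]) -> Optional[Callable]:
    """Return post_process unchanged if it is None or callable; otherwise
    look it up by name in the callback registry."""
    if post_process is None or callable(post_process):
        return post_process
    return CALLBACKS[post_process]


class ToyEstimator:
    """Hypothetical estimator whose train() only counts the samples seen so far."""

    def __init__(self):
        self.seen = 0

    def train(self, train_data, valid_data=None, **kwargs):
        self.seen += len(train_data)
        return {"samples_seen": self.seen, **kwargs}


def train(estimator, train_data, valid_data=None, post_process=None, **kwargs):
    """Forward kwargs to the estimator, pass the training result to the optional
    post-processing step, and return the (further trained) estimator."""
    callback = resolve_post_process(post_process)
    result = estimator.train(train_data, valid_data=valid_data, **kwargs)
    if callback is not None:
        callback(result)
    return estimator


# Usage: register a named callback, then train twice on successive batches.
history = []
CALLBACKS["record"] = history.append
est = train(ToyEstimator(), [1, 2, 3], post_process="record", epochs=2)
train(est, [4, 5], post_process="record")
print(est.seen)  # 5
print(history)   # [{'samples_seen': 3, 'epochs': 2}, {'samples_seen': 5}]
```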
- inference(data=None, post_process=None, **kwargs)
Inference task for IncrementalLearning
- Parameters:
data (BaseDataSource) – data source used for inference; see sedna.datasources.BaseDataSource for more details.
post_process (function or a registered method) – applied after estimator inference.
kwargs (Dict) – parameters for estimator inference, e.g., ntree_limit in xgboost.XGBClassifier.
- Returns:
inference result (object)
result after post_process (object)
whether the sample is a hard example (bool)
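The three return values can be illustrated with a small sketch. It assumes an estimator exposing predict() and a hard example check that is any callable returning a bool; FixedScoreEstimator, normalize, the explicit hem argument and the 0.9 threshold are all hypothetical, and in the real class the HEM algorithm would come from the hard_example_mining configuration instead.

```python
def normalize(scores):
    """Example post-processing step: rescale raw scores so they sum to 1."""
    total = sum(scores)
    return [s / total for s in scores]


class FixedScoreEstimator:
    """Hypothetical estimator that returns the same class scores for any input."""

    def __init__(self, scores):
        self.scores = scores

    def predict(self, data, **kwargs):
        return list(self.scores)


def inference(estimator, data, post_process=None, hem=None, **kwargs):
    """Return (raw result, post-processed result, is_hard_example).
    `hem` is any callable that flags hard examples; it is passed in
    explicitly here rather than built from a config."""
    infer_res = estimator.predict(data, **kwargs)
    res = post_process(infer_res) if post_process is not None else infer_res
    is_hard_example = bool(hem(res)) if hem is not None else False
    return infer_res, res, is_hard_example


raw, probs, hard = inference(
    FixedScoreEstimator([2.0, 1.0, 1.0]),
    data=None,
    post_process=normalize,
    hem=lambda p: max(p) < 0.9,  # hypothetical rule: low top confidence means "hard"
)
print(raw, probs, hard)  # [2.0, 1.0, 1.0] [0.5, 0.25, 0.25] True
```

Returning the raw result alongside the post-processed one lets a caller upload the original prediction for a flagged hard example while still acting on the processed output.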
- evaluate(data, post_process=None, **kwargs)
Evaluate task for IncrementalLearning
- Parameters:
data (BaseDataSource) – data source used for evaluation; see sedna.datasources.BaseDataSource for more details.
post_process (function or a registered method) – applied after estimator evaluation.
kwargs (Dict) – parameters for estimator evaluation, e.g., metric_name in xgboost.XGBClassifier.
- Returns:
Evaluation metrics.
- Return type:
List
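Since evaluate returns a list of metrics, here is a hedged sketch of producing such a list from named metric functions. The entry format {"metric": name, "score": value}, the predict callable in place of an estimator, and the accuracy/error_rate helpers are assumptions for illustration; the documentation above does not specify the shape of each list element.

```python
from typing import Callable, Dict, List, Sequence


def accuracy(y_true: Sequence, y_pred: Sequence) -> float:
    """Fraction of predictions that match the labels."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def error_rate(y_true: Sequence, y_pred: Sequence) -> float:
    return 1.0 - accuracy(y_true, y_pred)


def evaluate(predict: Callable, xs: Sequence, ys: Sequence,
             metrics: Dict[str, Callable], post_process: Callable = None) -> List[dict]:
    """Predict on the evaluation set, optionally post-process the predictions,
    and return one {"metric": name, "score": value} entry per metric."""
    preds = [predict(x) for x in xs]
    if post_process is not None:
        preds = post_process(preds)
    return [{"metric": name, "score": fn(ys, preds)} for name, fn in metrics.items()]


report = evaluate(
    predict=lambda x: int(x > 0.5),  # toy threshold model
    xs=[0.1, 0.7, 0.4, 0.9],
    ys=[0, 1, 1, 1],
    metrics={"accuracy": accuracy, "error_rate": error_rate},
)
print(report)
# [{'metric': 'accuracy', 'score': 0.75}, {'metric': 'error_rate', 'score': 0.25}]
```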