Source code for lib.sedna.core.incremental_learning.incremental_learning

# Copyright 2021 The KubeEdge Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import deepcopy

from sedna.common.file_ops import FileOps
from sedna.common.constant import K8sResourceKind, K8sResourceKindStatus
from sedna.common.class_factory import ClassFactory, ClassType
from sedna.core.base import JobBase

__all__ = ("IncrementalLearning",)


class IncrementalLearning(JobBase):
    """
    Incremental learning is a method of machine learning in which input data
    is continuously used to extend the existing model's knowledge, i.e. to
    further train the model. It represents a dynamic technique of supervised
    learning and unsupervised learning that can be applied when training data
    becomes available gradually over time.

    Sedna provides the related interfaces for application development.

    Parameters
    ----------
    estimator : Instance
        An instance with the high-level API that greatly simplifies
        machine learning programming. Estimators encapsulate training,
        evaluation, prediction, and exporting for your model.
    hard_example_mining : Dict
        HEM algorithm with parameters which has been registered to
        ClassFactory, see `sedna.algorithms.hard_example_mining` for more
        detail. When empty, the value is read through
        `get_hem_algorithm_from_config`.

    Examples
    --------
    >>> Estimator = keras.models.Sequential()
    >>> il_model = IncrementalLearning(
    ...     estimator=Estimator,
    ...     hard_example_mining={
    ...         "method": "IBT",
    ...         "param": {
    ...             "threshold_img": 0.9
    ...         }
    ...     }
    ... )

    Notes
    -----
    Sedna provides an interface called `get_hem_algorithm_from_config` to
    build the `hard_example_mining` parameter from the CRD definition.
    """

    def __init__(self, estimator, hard_example_mining: dict = None):
        super(IncrementalLearning, self).__init__(estimator=estimator)

        self.model_urls = self.get_parameters(
            "MODEL_URLS")  # use in evaluation
        self.job_kind = K8sResourceKind.INCREMENTAL_JOB.value
        FileOps.clean_folder([self.config.model_url], clean=False)

        # Stays None when no HEM configuration is available; `inference`
        # then always reports `is_hard_example=False`.
        self.hard_example_mining_algorithm = None
        if not hard_example_mining:
            # Fall back to the configuration-provided algorithm.
            hard_example_mining = self.get_hem_algorithm_from_config()
        if hard_example_mining:
            hem = hard_example_mining.get("method", "IBT")
            hem_parameters = hard_example_mining.get("param", {})
            self.hard_example_mining_algorithm = ClassFactory.get_cls(
                ClassType.HEM, hem
            )(**hem_parameters)

    @classmethod
    def get_hem_algorithm_from_config(cls, **param):
        """
        Get the `algorithm` name and `param` of hard_example_mining from crd

        Parameters
        ----------
        param : Dict
            update value in parameters of hard_example_mining

        Returns
        -------
        dict
            e.g.: {"method": "IBT", "param": {"threshold_img": 0.5}}

        Examples
        --------
        >>> IncrementalLearning.get_hem_algorithm_from_config(
        ...     threshold_img=0.9
        ... )
        {"method": "IBT", "param": {"threshold_img": 0.9}}
        """
        return cls.parameters.get_algorithm_from_api(
            algorithm="HEM",
            **param
        )

    def train(self, train_data,
              valid_data=None,
              post_process=None,
              **kwargs):
        """
        Training task for IncrementalLearning

        Parameters
        ----------
        train_data : BaseDataSource
            datasource use for train, see
            `sedna.datasources.BaseDataSource` for more detail.
        valid_data : BaseDataSource
            datasource use for evaluation, see
            `sedna.datasources.BaseDataSource` for more detail.
        post_process : function or a registered method
            effected after `estimator` training. A callable is used
            directly; any other non-None value is looked up in
            ClassFactory as a registered callback.
        kwargs : Dict
            parameters for `estimator` training,
            Like: `early_stopping_rounds` in Xgboost.XGBClassifier

        Returns
        -------
        estimator
            the result of `post_process(estimator)` when a callback is
            given, otherwise the estimator itself.
        """
        # Resolve the callback before training so that an unknown registered
        # name fails early. Callables are accepted directly, consistent with
        # `inference` and `evaluate`.
        callback_func = None
        if callable(post_process):
            callback_func = post_process
        elif post_process is not None:
            callback_func = ClassFactory.get_cls(
                ClassType.CALLBACK, post_process)

        res = self.estimator.train(
            train_data=train_data,
            valid_data=valid_data,
            **kwargs
        )
        model_paths = self.estimator.save(self.model_path)
        task_info_res = self.estimator.model_info(
            model_paths, result=res,
            relpath=self.config.data_path_prefix)
        self.report_task_info(
            None, K8sResourceKindStatus.COMPLETED.value, task_info_res)

        if callback_func:
            return callback_func(self.estimator)
        return self.estimator

    def inference(self, data=None, post_process=None, **kwargs):
        """
        Inference task for IncrementalLearning

        Parameters
        ----------
        data : BaseDataSource
            datasource use for inference, see
            `sedna.datasources.BaseDataSource` for more detail.
        post_process : function or a registered method
            effected after `estimator` inference.
        kwargs : Dict
            parameters for `estimator` inference,
            Like: `ntree_limit` in Xgboost.XGBClassifier

        Returns
        -------
        inference result : object
        result after post_process : object
        if is hard sample : bool
        """
        # Load the model on first use only.
        if not self.estimator.has_load:
            self.estimator.load(self.model_path)

        callback_func = None
        if callable(post_process):
            callback_func = post_process
        elif post_process is not None:
            callback_func = ClassFactory.get_cls(
                ClassType.CALLBACK, post_process)

        infer_res = self.estimator.predict(data, **kwargs)
        if callback_func:
            res = callback_func(
                deepcopy(infer_res)  # Prevent infer_result from being modified
            )
        else:
            res = infer_res

        # The HEM algorithm judges the post-processed result.
        is_hard_example = False
        if self.hard_example_mining_algorithm:
            is_hard_example = self.hard_example_mining_algorithm(res)
        return infer_res, res, is_hard_example

    def evaluate(self, data, post_process=None, **kwargs):
        """
        Evaluate task for IncrementalLearning

        Parameters
        ----------
        data : BaseDataSource
            datasource use for evaluation, see
            `sedna.datasources.BaseDataSource` for more detail.
        post_process : function or a registered method
            effected after `estimator` evaluation.
        kwargs : Dict
            parameters for `estimator` evaluate,
            Like: `metric_name` in Xgboost.XGBClassifier

        Returns
        -------
        evaluate metrics : List
            one entry per evaluated model.
        """
        callback_func = None
        if callable(post_process):
            callback_func = post_process
        elif post_process:
            callback_func = ClassFactory.get_cls(
                ClassType.CALLBACK, post_process)

        # `MODEL_URLS` (a ";"-separated list) takes precedence over the
        # single configured model url.
        all_models = []
        if self.model_urls:
            all_models = self.model_urls.split(";")
        elif self.config.model_url:
            all_models.append(self.config.model_url)

        final_res = []
        for model_url in all_models:
            if not model_url.strip():
                # Skip empty segments such as a trailing ";".
                continue
            self.estimator.model_save_path = model_url
            res = self.estimator.evaluate(
                data=data, model_path=model_url, **kwargs)
            if callback_func:
                res = callback_func(res)
            self.log.info(f"Evaluation with {model_url} : {res} ")
            task_info_res = self.estimator.model_info(
                model_url, result=res,
                relpath=self.config.data_path_prefix)
            # Keep only the first entry when a non-empty sequence is
            # returned.
            if isinstance(
                    task_info_res, (list, tuple)
            ) and len(task_info_res):
                task_info_res = list(task_info_res)[0]
            final_res.append(task_info_res)

        self.report_task_info(
            None, K8sResourceKindStatus.COMPLETED.value,
            final_res, kind="eval")
        return final_res