lib.sedna.algorithms.multi_task_learning.multi_task_learning
Multiple task transfer learning algorithms
Module Contents
Classes
- class lib.sedna.algorithms.multi_task_learning.multi_task_learning.MulTaskLearning(estimator=None, task_definition=None, task_relationship_discovery=None, task_mining=None, task_remodeling=None, inference_integrate=None)
An automated machine learning framework for edge-cloud multi-task learning.
See also
Train: Data + Estimator -> Task Definition -> Task Relationship Discovery -> Feature Engineering -> Training
Inference: Data -> Task Allocation -> Task Mining -> Feature Engineering -> Task Remodeling -> Inference
- Parameters:
estimator (Instance) – An instance with the high-level API that greatly simplifies machine learning programming. Estimators encapsulate training, evaluation, prediction, and exporting for your model.
task_definition (Dict) – Divides the data into multiple tasks; see task_jobs.task_definition for more detail.
task_relationship_discovery (Dict) – Discovers the relationships between all tasks; see task_jobs.task_relationship_discovery for more detail.
task_mining (Dict) – Mines the tasks that an inference sample belongs to; see task_jobs.task_mining for more detail.
task_remodeling (Dict) – Remodels tasks based on their relationships; see task_jobs.task_remodeling for more detail.
inference_integrate (Dict) – Integrates the inference results of all related tasks; see task_jobs.inference_integrate for more detail.
Examples
>>> from xgboost import XGBClassifier
>>> from sedna.algorithms.multi_task_learning import MulTaskLearning
>>> estimator = XGBClassifier(objective="binary:logistic")
>>> task_definition = {
...     "method": "TaskDefinitionByDataAttr",
...     "param": {"attribute": ["season", "city"]}
... }
>>> task_relationship_discovery = {
...     "method": "DefaultTaskRelationDiscover",
...     "param": {}
... }
>>> task_mining = {
...     "method": "TaskMiningByDataAttr",
...     "param": {"attribute": ["season", "city"]}
... }
>>> task_remodeling = None
>>> inference_integrate = {
...     "method": "DefaultInferenceIntegrate",
...     "param": {}
... }
>>> mul_task_instance = MulTaskLearning(
...     estimator=estimator,
...     task_definition=task_definition,
...     task_relationship_discovery=task_relationship_discovery,
...     task_mining=task_mining,
...     task_remodeling=task_remodeling,
...     inference_integrate=inference_integrate
... )
Notes
All methods are defined under task_jobs and registered in ClassFactory.
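The idea of resolving a {"method": ..., "param": ...} dictionary through a class registry can be sketched in plain Python. This is a minimal illustration only: the names _REGISTRY, register, build and GroupByAttr are hypothetical and are not Sedna's ClassFactory API, and the grouping rule (one task per distinct attribute tuple) is an assumption about what attribute-based task definition might look like.

```python
from typing import Dict

# Hypothetical stand-in for a class registry; NOT Sedna's ClassFactory.
_REGISTRY: Dict[str, type] = {}


def register(cls: type) -> type:
    """Register a class under its own name so it can be looked up by string."""
    _REGISTRY[cls.__name__] = cls
    return cls


@register
class GroupByAttr:
    """Toy task-definition method: one task per distinct attribute tuple."""

    def __init__(self, attribute):
        self.attribute = list(attribute)

    def __call__(self, samples):
        tasks = {}
        for sample in samples:
            # The task key is the tuple of the selected attribute values.
            key = tuple(sample[a] for a in self.attribute)
            tasks.setdefault(key, []).append(sample)
        return tasks


def build(config):
    """Resolve a {"method": ..., "param": ...} dict to a configured instance."""
    cls = _REGISTRY[config["method"]]
    return cls(**config.get("param", {}))


config = {"method": "GroupByAttr", "param": {"attribute": ["season", "city"]}}
task_definition = build(config)

samples = [
    {"season": "summer", "city": "A", "x": 1.0},
    {"season": "summer", "city": "A", "x": 2.0},
    {"season": "winter", "city": "B", "x": 3.0},
]
tasks = task_definition(samples)
```

Here the two summer samples from city A end up in one task and the winter sample in another, which mirrors how a method name plus its parameters selects and configures a task job.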
- train(train_data: sedna.datasources.BaseDataSource, valid_data: sedna.datasources.BaseDataSource = None, post_process=None, **kwargs)
Fit the estimator and update the knowledge based on the training data.
- Parameters:
train_data (BaseDataSource) – Train data, see sedna.datasources.BaseDataSource for more detail.
valid_data (BaseDataSource) – Validation data, a BaseDataSource or None.
post_process (function) – A function or registered method, called back after estimator training.
kwargs (Dict) – Parameters for estimator training, e.g. early_stopping_rounds in xgboost.XGBClassifier.
- Returns:
feedback (Dict) – contains the training results of every task.
task_index_url (str) – task extractor model path, used for task mining.
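As a rough illustration of the per-task feedback described above, the following sketch trains one toy model per task and collects a feedback dictionary keyed by task name. Everything here is an assumption made for illustration: MeanEstimator and train_per_task are hypothetical names, the feedback content is invented, and the sketch returns the models in memory rather than writing a task index file.

```python
from statistics import mean


class MeanEstimator:
    """Toy estimator: predicts the mean of the labels it was trained on."""

    def train(self, xs, ys):
        self.value = mean(ys)
        # Return a small summary, standing in for real training results.
        return {"n_samples": len(ys)}

    def predict(self, xs):
        return [self.value for _ in xs]


def train_per_task(tasks):
    """Train one model per task and gather each task's training summary."""
    models, feedback = {}, {}
    for name, (xs, ys) in tasks.items():
        model = MeanEstimator()
        feedback[name] = model.train(xs, ys)
        models[name] = model
    return models, feedback


tasks = {
    "summer": ([[0.1], [0.2]], [2.0, 4.0]),
    "winter": ([[0.3]], [10.0]),
}
models, feedback = train_per_task(tasks)
```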
- load(task_index_url=None)
Load task_detail (tasks, models, etc.) from the task index file. It is loaded automatically during the inference and evaluation phases.
- Parameters:
task_index_url (str) – task index file path, default self.task_index_url.
- predict(data: sedna.datasources.BaseDataSource, post_process=None, **kwargs)
Predict the results for the input data based on the knowledge gained in training.
- Parameters:
data (BaseDataSource) – inference sample, see sedna.datasources.BaseDataSource for more detail.
post_process (function) – A function or registered method, applied after estimator prediction, e.g. a label transform.
kwargs (Dict) – Parameters for estimator prediction, e.g. ntree_limit in xgboost.XGBClassifier.
- Returns:
result (array_like) – results array, containing the inference result for each sample.
tasks (List) – tasks assigned to each sample.
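The two return values above (one result per sample, plus the task each sample was assigned to) can be pictured with a small routing sketch. It assumes, purely for illustration, that task mining is a lookup on attribute values and that each task's model is a plain callable; predict_by_task is a hypothetical helper, not part of Sedna.

```python
def predict_by_task(samples, attribute, models):
    """Route each sample to the model of its task and collect the outputs.

    Returns (results, assigned): one prediction per sample, and the task key
    each sample was assigned to (None when no model exists for that key).
    """
    results, assigned = [], []
    for sample in samples:
        key = tuple(sample[a] for a in attribute)
        model = models.get(key)
        if model is None:
            # No model was trained for this task in this toy setup.
            results.append(None)
            assigned.append(None)
        else:
            results.append(model(sample))
            assigned.append(key)
    return results, assigned


# Hypothetical per-task models, keyed by (season, city).
models = {("summer", "A"): lambda s: s["x"] * 2}

samples = [
    {"season": "summer", "city": "A", "x": 1.5},
    {"season": "winter", "city": "B", "x": 2.0},
]
results, assigned = predict_by_task(samples, ["season", "city"], models)
```

A real pipeline would also need a policy for samples whose task has no model, such as falling back to a related task; this sketch simply marks them with None.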
- evaluate(data: sedna.datasources.BaseDataSource, metrics=None, metrics_param=None, **kwargs)
Evaluate the performance of each trained task and filter the tasks based on the defined rules.
- Parameters:
data (BaseDataSource) – Validation data, see sedna.datasources.BaseDataSource for more detail.
metrics (function / str) – Metric used to assess task performance from the given predictions.
metrics_param (Dict) – Parameters for the metrics function.
kwargs (Dict) – Parameters for estimator evaluation, e.g. ntree_limit in xgboost.XGBClassifier.
- Returns:
task_eval_res (Dict) – all metric results.
tasks_detail (List[Object]) – all metric results in each task.
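Per-task evaluation followed by filtering might look like the sketch below. The filtering rule used here (keep tasks whose score meets a threshold) is an assumption, since the rules themselves are not specified on this page; accuracy and evaluate_tasks are hypothetical helpers written for this example.

```python
def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true labels."""
    correct = sum(1 for t, p in zip(y_true, y_pred) if t == p)
    return correct / len(y_true)


def evaluate_tasks(task_outputs, metric, threshold):
    """Score every task with `metric`, then keep tasks scoring >= threshold."""
    scores = {
        name: metric(y_true, y_pred)
        for name, (y_true, y_pred) in task_outputs.items()
    }
    kept = [name for name, score in scores.items() if score >= threshold]
    return scores, kept


# Hypothetical (true labels, predicted labels) pairs per task.
task_outputs = {
    "summer": ([1, 0, 1, 1], [1, 0, 0, 1]),
    "winter": ([1, 0], [0, 0]),
}
scores, kept = evaluate_tasks(task_outputs, accuracy, threshold=0.6)
```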