lib.sedna.algorithms.multi_task_learning.multi_task_learning

Multi-task transfer learning algorithms.

Module Contents

Classes

MulTaskLearning


class lib.sedna.algorithms.multi_task_learning.multi_task_learning.MulTaskLearning(estimator=None, task_definition=None, task_relationship_discovery=None, task_mining=None, task_remodeling=None, inference_integrate=None)

An automated machine learning framework for edge-cloud multi-task learning.

See also

Train

Data + Estimator -> Task Definition -> Task Relationship Discovery -> Feature Engineering -> Training

Inference

Data -> Task Allocation -> Task Mining -> Feature Engineering -> Task Remodeling -> Inference
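The train and inference pipelines above can be illustrated with a deliberately simplified sketch. It assumes that tasks are defined by, and mined from, the values of data attributes (as the TaskDefinitionByDataAttr and TaskMiningByDataAttr methods in the Examples suggest). It uses a hypothetical MajorityClassifier in place of a real estimator and skips task relationship discovery, feature engineering, task remodeling, and inference integration. All helper names, and the fallback to a global model for unseen tasks, are illustrative assumptions and not part of Sedna's API.

```python
from collections import Counter, defaultdict
from typing import Dict, List, Optional, Sequence, Tuple


class MajorityClassifier:
    """Hypothetical toy estimator: always predicts its most frequent training label."""

    def fit(self, labels: Sequence[int]) -> "MajorityClassifier":
        self.label = Counter(labels).most_common(1)[0][0]
        return self

    def predict(self) -> int:
        return self.label


def task_key(sample: Dict[str, str], attributes: Sequence[str]) -> Tuple[str, ...]:
    # In this sketch a task is identified by the values of the chosen attributes.
    return tuple(sample[a] for a in attributes)


def train_per_task(samples, labels, attributes):
    # Task definition: group the training labels by task key.
    groups: Dict[Tuple[str, ...], List[int]] = defaultdict(list)
    for sample, label in zip(samples, labels):
        groups[task_key(sample, attributes)].append(label)
    # Training: one model per task, plus a global model for unseen tasks (assumption).
    models = {key: MajorityClassifier().fit(ys) for key, ys in groups.items()}
    fallback = MajorityClassifier().fit(labels)
    return models, fallback


def predict_per_task(samples, attributes, models, fallback):
    results: List[int] = []
    tasks: List[Optional[Tuple[str, ...]]] = []
    for sample in samples:
        # Task mining: assign each inference sample to a known task if possible.
        key = task_key(sample, attributes)
        if key in models:
            results.append(models[key].predict())
            tasks.append(key)
        else:
            results.append(fallback.predict())
            tasks.append(None)
    return results, tasks


attrs = ["season", "city"]
train_x = [
    {"season": "summer", "city": "A"},
    {"season": "summer", "city": "A"},
    {"season": "winter", "city": "B"},
]
train_y = [1, 1, 0]
models, fallback = train_per_task(train_x, train_y, attrs)
preds, tasks = predict_per_task(
    [{"season": "winter", "city": "B"}, {"season": "spring", "city": "C"}],
    attrs, models, fallback,
)
print(preds, tasks)  # [0, 1] [('winter', 'B'), None]
```

Like the predict method documented below, the sketch returns both the results and the task assigned to each sample.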

Parameters:
  • estimator (Instance) – An estimator instance exposing a high-level machine learning API. Estimators encapsulate training, evaluation, prediction, and export of your model.

  • task_definition (Dict) – Configuration of the method that divides the data into multiple tasks; see task_jobs.task_definition for more detail.

  • task_relationship_discovery (Dict) – Configuration of the method that discovers relationships between tasks; see task_jobs.task_relationship_discovery for more detail.

  • task_mining (Dict) – Configuration of the method that mines the tasks of each inference sample; see task_jobs.task_mining for more detail.

  • task_remodeling (Dict) – Configuration of the method that remodels tasks based on their relationships; see task_jobs.task_remodeling for more detail.

  • inference_integrate (Dict) – Configuration of the method that integrates the inference results of all related tasks; see task_jobs.inference_integrate for more detail.

Examples

>>> from xgboost import XGBClassifier
>>> from sedna.algorithms.multi_task_learning import MulTaskLearning
>>> estimator = XGBClassifier(objective="binary:logistic")
>>> task_definition = {
...     "method": "TaskDefinitionByDataAttr",
...     "param": {"attribute": ["season", "city"]}
... }
>>> task_relationship_discovery = {
...     "method": "DefaultTaskRelationDiscover", "param": {}
... }
>>> task_mining = {
...     "method": "TaskMiningByDataAttr",
...     "param": {"attribute": ["season", "city"]}
... }
>>> task_remodeling = None
>>> inference_integrate = {
...     "method": "DefaultInferenceIntegrate", "param": {}
... }
>>> mul_task_instance = MulTaskLearning(
...     estimator=estimator,
...     task_definition=task_definition,
...     task_relationship_discovery=task_relationship_discovery,
...     task_mining=task_mining,
...     task_remodeling=task_remodeling,
...     inference_integrate=inference_integrate
... )

Notes

All methods are defined under task_jobs and registered in ClassFactory.

train(train_data: sedna.datasources.BaseDataSource, valid_data: sedna.datasources.BaseDataSource = None, post_process=None, **kwargs)

Fit the estimator and update the task knowledge based on the training data.

Parameters:
  • train_data (BaseDataSource) – Training data; see sedna.datasources.BaseDataSource for more detail.

  • valid_data (BaseDataSource) – Validation data, either a BaseDataSource or None.

  • post_process (function) – A function or registered method, called as a callback after estimator training.

  • kwargs (Dict) – Parameters for estimator training, e.g. early_stopping_rounds for xgboost.XGBClassifier.

Returns:

  • feedback (Dict) – Training results for each task.

  • task_index_url (str) – Path of the task extractor model, used for task mining.

load(task_index_url=None)

Load task details (tasks, models, etc.) from the task index file. Task details are loaded automatically during the inference and evaluation phases.

Parameters:
  • task_index_url (str) – Path of the task index file; defaults to self.task_index_url.

predict(data: sedna.datasources.BaseDataSource, post_process=None, **kwargs)

Predict results for the input data based on the knowledge learned during training.

Parameters:
  • data (BaseDataSource) – Inference samples; see sedna.datasources.BaseDataSource for more detail.

  • post_process (function) – A function or registered method applied after estimator prediction, e.g. a label transform.

  • kwargs (Dict) – Parameters for estimator prediction, e.g. ntree_limit for xgboost.XGBClassifier.

Returns:

  • result (array_like) – Array of inference results, one per sample.

  • tasks (List) – Tasks assigned to each sample.
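The post_process argument accepts a function applied after estimator prediction, such as a label transform. As a hedged illustration, the hypothetical proba_to_label below maps per-class probability rows to class indices. It assumes post_process receives the raw prediction array and returns the transformed result, which may not match the exact call signature Sedna uses.

```python
from typing import List, Sequence


def proba_to_label(result: Sequence[Sequence[float]]) -> List[int]:
    """Hypothetical label transform: return the index of the largest probability per row."""
    return [max(range(len(row)), key=lambda i: row[i]) for row in result]


print(proba_to_label([[0.2, 0.8], [0.9, 0.1]]))  # [1, 0]
```

Under that assumption, it would be passed as predict(data, post_process=proba_to_label).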

evaluate(data: sedna.datasources.BaseDataSource, metrics=None, metrics_param=None, **kwargs)

Evaluate the performance of each task learned during training and filter tasks based on the defined rules.

Parameters:
  • data (BaseDataSource) – Validation data; see sedna.datasources.BaseDataSource for more detail.

  • metrics (function / str) – Metric used to assess each task's performance from its predictions, given either as a function or as a string.

  • metrics_param (Dict) – Parameters for the metrics function.

  • kwargs (Dict) – Parameters for estimator evaluation, e.g. ntree_limit for xgboost.XGBClassifier.

Returns:

  • task_eval_res (Dict) – All metric results.

  • tasks_detail (List[Object]) – Metric results for each task.
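evaluate scores each task and filters tasks according to the defined rules. The following sketch shows one possible rule under stated assumptions. Samples are grouped by the task they were assigned to, and each group is scored with a metric function (plain accuracy here). Tasks scoring below a hypothetical threshold are then dropped. The function names and the threshold rule are illustrative and are not Sedna's implementation.

```python
from typing import Callable, Dict, Hashable, List, Sequence, Tuple


def accuracy(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    """Fraction of predictions that match the ground truth."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def evaluate_tasks(
    y_true: Sequence[int],
    y_pred: Sequence[int],
    tasks: Sequence[Hashable],
    metric: Callable[[Sequence[int], Sequence[int]], float] = accuracy,
    threshold: float = 0.5,  # hypothetical filtering rule
) -> Tuple[Dict[Hashable, float], List[Hashable]]:
    # Group ground truth and predictions by the task each sample was assigned to.
    grouped: Dict[Hashable, Tuple[List[int], List[int]]] = {}
    for t, p, task in zip(y_true, y_pred, tasks):
        ys, ps = grouped.setdefault(task, ([], []))
        ys.append(t)
        ps.append(p)
    # Score each task separately, then keep only tasks that meet the threshold.
    scores = {task: metric(ys, ps) for task, (ys, ps) in grouped.items()}
    kept = [task for task, score in scores.items() if score >= threshold]
    return scores, kept


scores, kept = evaluate_tasks([1, 0, 1, 1], [1, 0, 0, 0], ["a", "a", "b", "b"])
print(scores, kept)  # {'a': 1.0, 'b': 0.0} ['a']
```

Passing the metric as a function mirrors the metrics parameter above. Any extra metric settings, which evaluate takes through metrics_param, could be bound beforehand with functools.partial in this sketch.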