Source code for lib.sedna.core.lifelong_learning.knowledge_management.cloud_knowledge_management

# Copyright 2023 The KubeEdge Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import time
import tempfile

from sedna.common.class_factory import ClassType, ClassFactory
from sedna.common.file_ops import FileOps
from sedna.common.config import Context
from sedna.common.constant import KBResourceConstant

from .base_knowledge_management import BaseKnowledgeManagement

__all__ = ('CloudKnowledgeManagement', )


@ClassFactory.register(ClassType.KM)
class CloudKnowledgeManagement(BaseKnowledgeManagement):
    """
    Manage task processing, knowledge base updates, task deployment,
    etc., on the cloud side.
    """

    def __init__(self, config, seen_estimator, unseen_estimator, **kwargs):
        super(CloudKnowledgeManagement, self).__init__(
            config, seen_estimator, unseen_estimator)

        self.last_task_index = kwargs.get("last_task_index", None)
        self.cloud_output_url = config.get("cloud_output_url", "/tmp")
        self.task_index = FileOps.join_path(
            self.cloud_output_url, config["task_index"])
        self.local_task_index_url = KBResourceConstant.KB_INDEX_NAME.value
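    # A minimal usage sketch (illustrative, not part of the library): the
    # config keys mirror what __init__ above reads; the estimator objects
    # and file names are hypothetical placeholders.
    #
    #     config = {
    #         "task_index": "index.pkl",      # joined under cloud_output_url
    #         "cloud_output_url": "/tmp/kb",  # defaults to "/tmp" if absent
    #     }
    #     km = CloudKnowledgeManagement(
    #         config, seen_estimator=seen_est, unseen_estimator=unseen_est)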
    def update_kb(self, task_index):
        if isinstance(task_index, str):
            task_index = FileOps.load(task_index)

        seen_task_index = task_index.get(self.seen_task_key)
        unseen_task_index = task_index.get(self.unseen_task_key)

        seen_extractor, seen_task_groups = self.save_task_index(
            seen_task_index, task_type=self.seen_task_key)
        unseen_extractor, unseen_task_groups = self.save_task_index(
            unseen_task_index, task_type=self.unseen_task_key)

        task_info = {
            self.seen_task_key: {
                self.task_group_key: seen_task_groups,
                self.extractor_key: seen_extractor
            },
            self.unseen_task_key: {
                self.task_group_key: unseen_task_groups,
                self.extractor_key: unseen_extractor
            },
            "create_time": str(time.time())
        }

        fd, name = tempfile.mkstemp()
        os.close(fd)  # close the low-level handle; FileOps.dump reopens by path
        FileOps.dump(task_info, name)

        index_file = self.kb_client.update_db(name)
        if not index_file:
            self.log.error("KB update failed!")
            index_file = name

        return FileOps.upload(index_file, self.task_index)
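    # Sketch of the ``task_index`` structure that update_kb above consumes,
    # assuming the default key names ("seen_task", "unseen_task",
    # "task_groups", "extractor") inherited from BaseKnowledgeManagement;
    # the concrete values are illustrative.
    #
    #     task_index = {
    #         "seen_task": {
    #             "task_groups": [...],  # TaskGroup objects with models/samples
    #             "extractor": "/tmp/kb/seen_task_extractor.pkl",
    #         },
    #         "unseen_task": {
    #             "task_groups": [...],
    #             "extractor": "/tmp/kb/unseen_task_extractor.pkl",
    #         },
    #     }
    #     new_index_url = km.update_kb(task_index)  # a dumped file path works too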
    def save_task_index(self, task_index, task_type="seen_task"):
        extractor = task_index[self.extractor_key]
        if isinstance(extractor, str):
            extractor = FileOps.load(extractor)
        task_groups = task_index[self.task_group_key]

        model_upload_key = {}
        for task_group in task_groups:
            model_file = task_group.model.model
            save_model = FileOps.join_path(
                self.cloud_output_url, task_type,
                os.path.basename(model_file)
            )
            if model_file not in model_upload_key:
                model_upload_key[model_file] = FileOps.upload(
                    model_file, save_model)
            model_file = model_upload_key[model_file]

            try:
                model = self.kb_client.upload_file(save_model)
                self.log.info(
                    f"Upload task model to {model} successfully."
                )
            except Exception as err:
                self.log.error(
                    f"Upload task model of {model_file} failed: {err}"
                )
                model = FileOps.join_path(
                    self.cloud_output_url, task_type,
                    os.path.basename(model_file))

            task_group.model.model = model

            for _task in task_group.tasks:
                _task.model = FileOps.join_path(
                    self.cloud_output_url, task_type,
                    os.path.basename(model_file))
                sample_dir = FileOps.join_path(
                    self.cloud_output_url, task_type,
                    f"{_task.samples.data_type}_{_task.entry}.sample")
                task_group.samples.save(sample_dir)

                try:
                    sample_dir = self.kb_client.upload_file(sample_dir)
                    self.log.info(
                        f"Upload task sample to {sample_dir} successfully.")
                except Exception as err:
                    self.log.error(
                        f"Upload task samples of {_task.entry} failed: {err}")
                _task.samples.data_url = sample_dir

        save_extractor = FileOps.join_path(
            self.cloud_output_url, task_type,
            f"{task_type}_{KBResourceConstant.TASK_EXTRACTOR_NAME.value}"
        )
        extractor = FileOps.dump(extractor, save_extractor)
        try:
            extractor = self.kb_client.upload_file(extractor)
            self.log.info(
                f"Upload task extractor to {extractor} successfully.")
        except Exception as err:
            self.log.error(f"Upload task extractor failed: {err}")

        return extractor, task_groups
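    # Rough layout produced under ``cloud_output_url`` by save_task_index
    # above for task_type="seen_task"; file names are illustrative and
    # <extractor-name> stands for KBResourceConstant.TASK_EXTRACTOR_NAME:
    #
    #     /tmp/kb/seen_task/
    #         <model file>                    # one copy per distinct model
    #         <data_type>_<entry>.sample      # per-task sample dump
    #         seen_task_<extractor-name>      # dumped task extractor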
    def evaluate_tasks(self, tasks_detail, **kwargs):
        """
        Parameters
        ----------
        tasks_detail: List[Task]
            output of module task_update_decision, consisting of results
            of evaluation.

        Returns
        -------
        drop_task: str
            comma-separated names of the tasks which will not be deployed
            to the edge.
        """
        self.model_filter_operator = Context.get_parameters("operator", ">")
        self.model_threshold = float(
            Context.get_parameters("model_threshold", 0.1))

        drop_tasks = []

        operator_map = {
            ">": lambda x, y: x > y,
            "<": lambda x, y: x < y,
            "=": lambda x, y: x == y,
            ">=": lambda x, y: x >= y,
            "<=": lambda x, y: x <= y,
        }
        if self.model_filter_operator not in operator_map:
            self.log.warn(
                f"operator {self.model_filter_operator} used to "
                f"compare is not allowed, defaulting to '<'"
            )
            self.model_filter_operator = "<"
        operator_func = operator_map[self.model_filter_operator]

        for detail in tasks_detail:
            scores = detail.scores
            entry = detail.entry
            self.log.info(f"{entry} scores: {scores}")

            if any(map(lambda x: operator_func(float(x),
                                               self.model_threshold),
                       scores.values())):
                self.log.warn(
                    f"{entry} will not be deployed because some of its "
                    f"scores are {self.model_filter_operator} "
                    f"{self.model_threshold}")
                drop_tasks.append(entry)
                continue

        drop_task = ",".join(drop_tasks)

        return drop_task
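    # Worked example for evaluate_tasks above, using the defaults read from
    # Context (operator ">" and model_threshold 0.1): a task is dropped as
    # soon as *any* of its scores satisfies the comparison.
    #
    #     detail.entry = "task_a"
    #     detail.scores = {"accuracy": 0.05, "f1": 0.2}
    #     0.2 > 0.1  ->  True, so "task_a" is dropped
    #     evaluate_tasks([detail])  ->  "task_a"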