# Source code for sedna.core.lifelong_learning.knowledge_management.edge_knowledge_management

# Copyright 2023 The KubeEdge Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import time
import tempfile
import threading

from watchdog.observers import Observer
from watchdog.events import *
from sedna.common.log import LOGGER
from sedna.common.config import Context, BaseConfig
from sedna.common.class_factory import ClassType, ClassFactory
from sedna.common.file_ops import FileOps
from sedna.common.constant import KBResourceConstant, K8sResourceKindStatus

from .base_knowledge_management import BaseKnowledgeManagement

__all__ = ('EdgeKnowledgeManagement', )


@ClassFactory.register(ClassType.KM)
class EdgeKnowledgeManagement(BaseKnowledgeManagement):
    """Manage inference, knowledge base update, etc., at the edge."""

    def __init__(self, config, seen_estimator, unseen_estimator, **kwargs):
        """
        :param config: lifelong-learning configuration passed through to
            the base knowledge management class.
        :param seen_estimator: estimator handling seen tasks.
        :param unseen_estimator: estimator handling unseen tasks.
        """
        super(EdgeKnowledgeManagement, self).__init__(
            config, seen_estimator, unseen_estimator)

        # Local root for everything the edge KB stores.
        self.edge_output_url = Context.get_parameters(
            "edge_output_url", KBResourceConstant.EDGE_KB_DIR.value)
        # Local path of the deployed task index file.
        self.task_index = FileOps.join_path(
            self.edge_output_url, KBResourceConstant.KB_INDEX_NAME.value)
        # Directory watched for unseen samples awaiting upload.
        self.local_unseen_save_url = FileOps.join_path(
            self.edge_output_url, "unseen_samples")
        os.makedirs(self.local_unseen_save_url, exist_ok=True)

        self.pinned_service_start = False
        self.unseen_sample_observer = None
        # Index version (create_time) currently deployed at the edge, and
        # the newest version observed from the cloud.  NOTE: the "lastest"
        # spelling is kept because ModelHotUpdateThread writes this field.
        self.current_index_version = None
        self.lastest_index_version = None
[docs] def update_kb(self, task_index): if isinstance(task_index, str): try: task_index = FileOps.load(task_index) except Exception as err: self.log.error(f"{err}") self.log.error( "Load task index failed. " "KB deployment to the edge failed.") return None seen_task_index = task_index.get(self.seen_task_key) unseen_task_index = task_index.get(self.unseen_task_key) seen_extractor, seen_task_groups = self.save_task_index( seen_task_index, task_type=self.seen_task_key) unseen_extractor, unseen_task_groups = self.save_task_index( unseen_task_index, task_type=self.unseen_task_key) task_info = { self.seen_task_key: { self.task_group_key: seen_task_groups, self.extractor_key: seen_extractor }, self.unseen_task_key: { self.task_group_key: unseen_task_groups, self.extractor_key: unseen_extractor }, "create_time": task_index.get("create_time", str(time.time())) } self.current_index_version = str(task_info.get("create_time")) self.lastest_index_version = self.current_index_version fd, name = tempfile.mkstemp() FileOps.dump(task_info, name) return FileOps.upload(name, self.task_index)
    def save_task_index(self, task_index, task_type="seen_task"):
        """Localize one side (seen or unseen) of a task index.

        Downloads every task-group model and sample set referenced by
        ``task_index`` into ``<edge_output_url>/<task_type>/`` and rewrites
        the in-memory entries to point at the local copies.

        :param task_index: dict holding an extractor entry (possibly a URL
            to load) and the task-group entries.
        :param task_type: sub-directory name, "seen_task" or "unseen_task".
        :return: tuple ``(local extractor url, updated task groups)``.
        """
        extractor = task_index[self.extractor_key]
        if isinstance(extractor, str):
            extractor = FileOps.load(extractor)
        task_groups = task_index[self.task_group_key]

        # Maps remote model URL -> downloaded local path so that a model
        # shared by several task groups is downloaded only once.
        model_upload_key = {}
        for task in task_groups:
            model_file = task.model.model
            save_model = FileOps.join_path(
                self.edge_output_url, task_type,
                os.path.basename(model_file)
            )
            if model_file not in model_upload_key:
                model_upload_key[model_file] = FileOps.download(
                    model_file, save_model)
            # Re-bind to the local path; its basename matches save_model's.
            model_file = model_upload_key[model_file]

            task.model.model = save_model

            for _task in task.tasks:
                # Point each sub-task at the localized model file.
                _task.model = FileOps.join_path(
                    self.edge_output_url, task_type,
                    os.path.basename(model_file))
                sample_dir = FileOps.join_path(
                    self.edge_output_url, task_type,
                    f"{_task.samples.data_type}_{_task.entry}.sample")
                _task.samples.data_url = FileOps.download(
                    _task.samples.data_url, sample_dir)
                self.log.info(f"Download {_task.entry} to the edge.")

        save_extractor = FileOps.join_path(
            self.edge_output_url, task_type,
            KBResourceConstant.TASK_EXTRACTOR_NAME.value
        )
        extractor = FileOps.dump(extractor, save_extractor)

        return extractor, task_groups
[docs] def save_unseen_samples(self, samples, post_process): for sample in samples.x: if callable(post_process): # customized sample saving function post_process(sample, self.local_unseen_save_url) continue if isinstance(sample, dict): img = sample.get("image") image_name = "{}.png".format(str(time.time())) image_url = FileOps.join_path( self.local_unseen_save_url, image_name) img.save(image_url) else: image_name = os.path.basename(sample[0]) image_url = FileOps.join_path( self.local_unseen_save_url, image_name) FileOps.upload(sample[0], image_url, clean=False) LOGGER.info(f"Unseen sample uploading completes.")
[docs] def start_services(self): self.unseen_sample_observer = Observer() self.unseen_sample_observer.schedule( UnseenSampleUploadingHandler(), self.local_unseen_save_url, True) self.unseen_sample_observer.start() ModelHotUpdateThread(self).start()
class ModelHotUpdateThread(threading.Thread):
    """Hot task index loading with multithread support.

    Periodically loads the cloud task index named by the ``MODEL_URLS``
    parameter and records its ``create_time`` as the latest known index
    version on the owning ``EdgeKnowledgeManagement`` instance.
    """
    # Single-permit semaphore reserved for model-file manipulation.
    MODEL_MANIPULATION_SEM = threading.Semaphore(1)

    def __init__(self, edge_knowledge_management, callback=None):
        """
        :param edge_knowledge_management: owner whose
            ``current_index_version`` is read and whose
            ``lastest_index_version`` is updated by the poll loop.
        :param callback: reserved hook; stored but currently unused here.
        """
        model_check_time = int(Context.get_parameters(
            "MODEL_POLL_PERIOD_SECONDS", "30"))
        if model_check_time < 1:
            LOGGER.warning("Catch an abnormal value in "
                           "`MODEL_POLL_PERIOD_SECONDS`, fallback with 30")
            model_check_time = 30
        self.edge_knowledge_management = edge_knowledge_management
        self.check_time = model_check_time
        self.callback = callback
        super(ModelHotUpdateThread, self).__init__()
        # No placeholders in the message, so a plain string suffices.
        LOGGER.info("Model hot update service starts.")

    def run(self):
        """Poll loop: sleep, then refresh the latest index version."""
        while True:
            time.sleep(self.check_time)
            if not self.edge_knowledge_management.current_index_version:
                # Nothing deployed locally yet; skip this round.
                continue
            latest_task_index = Context.get_parameters("MODEL_URLS")
            if not latest_task_index:
                continue
            try:
                latest_task_index = FileOps.load(latest_task_index)
            except Exception:
                # A transient load failure must not kill the poll thread;
                # log it and retry on the next period.
                LOGGER.exception("Failed to load the latest task index.")
                continue
            self.edge_knowledge_management.lastest_index_version = str(
                latest_task_index.get("create_time"))


class UnseenSampleUploadingHandler(FileSystemEventHandler):
    """Watchdog handler uploading newly created unseen-sample files to
    the configured (possibly remote) ``unseen_save_url``."""

    def __init__(self):
        FileSystemEventHandler.__init__(self)
        self.unseen_save_url = Context.get_parameters(
            "unseen_save_url", os.path.join(
                BaseConfig.data_path_prefix, "unseen_samples"))
        if not FileOps.is_remote(self.unseen_save_url):
            os.makedirs(self.unseen_save_url, exist_ok=True)
        # No placeholders in the message, so a plain string suffices.
        LOGGER.info("Unseen sample uploading service starts.")

    def on_created(self, event):
        """Upload a just-created sample file to ``unseen_save_url``.

        The short sleep gives the writer time to finish the file before
        the upload reads it.
        """
        time.sleep(1.0)
        sample_name = os.path.basename(event.src_path)
        FileOps.upload(event.src_path, FileOps.join_path(
            self.unseen_save_url, sample_name))