# Source code for lib.sedna.algorithms.hard_example_mining.hard_example_mining

# Copyright 2021 The KubeEdge Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Hard Example Mining Algorithms"""

import abc
import math
import random
from sedna.common.class_factory import ClassFactory, ClassType

# Public hard example mining filters (each is registered with ClassFactory).
__all__ = ('ThresholdFilter', 'CrossEntropyFilter', 'IBTFilter',
           'RandomFilter')


class BaseFilter(metaclass=abc.ABCMeta):
    """Common interface shared by every hard example filter.

    Concrete filters override ``__call__`` to decide whether a single
    inference result should be treated as a hard sample.
    """

    def __call__(self, infer_result=None):
        """
        Decide whether ``infer_result`` is a hard sample.

        Parameters
        ----------
        infer_result : array_like
            prediction result produced by the model

        Returns
        -------
        is_hard_sample : bool
            `True` for a hard sample, `False` otherwise.

        Raises
        ------
        NotImplementedError
            Always; concrete filters must override this method.
        """
        raise NotImplementedError()

    @classmethod
    def data_check(cls, data):
        """Return True when ``data``, converted with ``float``, lies in [0, 1].

        Errors raised by ``float`` propagate to the caller; NaN gives False.
        """
        value = float(data)
        return 0 <= value <= 1


@ClassFactory.register(ClassType.HEM, alias="Threshold")
class ThresholdFilter(BaseFilter, abc.ABC):
    """
    **Object detection** Hard samples discovery methods named `Threshold`

    Parameters
    ----------
    threshold: float
        hard coefficient threshold score to filter img, default to 0.5.
    """

    def __init__(self, threshold: float = 0.5, **kwargs):
        self.threshold = float(threshold)

    def __call__(self, infer_result=None) -> bool:
        """Judge whether the image is a hard sample.

        The image is hard when the mean score of its predicted boxes is
        below ``threshold``.

        Parameters
        ----------
        infer_result: array_like
            prediction boxes list, such as [bbox1, bbox2, ...]; every bbox
            must have more than four elements and ``bbox[4]`` is read as the
            box score. Scores are not range-checked.

        Returns
        -------
        is hard sample: bool
            `True` means hard sample, `False` means not. Invalid input
            (``None``, an unsized object, an empty list, or a bbox with
            fewer than five elements) yields `False`.
        """
        # Use len() rather than truthiness: a numpy array (a valid
        # array_like input) raises ValueError when evaluated as a bool.
        try:
            num_boxes = len(infer_result)
        except TypeError:
            # None or an unsized object: invalid input.
            return False
        if num_boxes == 0 or not all(len(bbox) > 4 for bbox in infer_result):
            return False
        image_score = sum(bbox[4] for bbox in infer_result)
        # bool() so callers get a plain bool even when scores are numpy types.
        return bool(image_score / num_boxes < self.threshold)


@ClassFactory.register(ClassType.HEM, alias="CrossEntropy")
class CrossEntropyFilter(BaseFilter, abc.ABC):
    """
    **Object detection** Hard samples discovery methods named `CrossEntropy`

    Parameters
    ----------
    threshold_cross_entropy: float
        hard coefficient threshold score to filter img, default to 0.5.
    """

    def __init__(self, threshold_cross_entropy=0.5, **kwargs):
        self.threshold_cross_entropy = float(threshold_cross_entropy)

    def __call__(self, infer_result=None) -> bool:
        """Judge whether the image is a hard sample.

        The confidence score is ``1 + sum(p * log(p)) / log(n)`` over the
        ``n`` class scores ``p`` (one minus the entropy normalised by
        ``log(n)``). The sample is hard when this score is below
        ``threshold_cross_entropy``.

        Parameters
        ----------
        infer_result: array_like
            prediction classes list, such as
            [class1-score, class2-score, class3-score, ...], where each
            class-score is the score of that class and must lie in [0, 1].

        Returns
        -------
        is hard sample: bool
            `True` means hard sample, `False` means not. Invalid input
            yields `False`. Invalid input means ``None`` or an unsized
            object, fewer than two scores, or any score outside [0, 1].
        """
        # Use len() rather than truthiness: a numpy array (a valid
        # array_like input) raises ValueError when evaluated as a bool.
        try:
            num_classes = len(infer_result)
        except TypeError:
            # None or an unsized object: invalid input.
            return False
        # With a single class log(num_classes) == 0, so the score cannot be
        # normalised (division by zero); an empty result is invalid too.
        if num_classes < 2:
            return False
        if not all(self.data_check(score) for score in infer_result):
            return False
        # Entropy convention: 0 * log(0) == 0. math.log(0) would raise
        # ValueError, so zero scores are skipped.
        log_sum = sum(
            prob * math.log(prob)
            for prob in map(float, infer_result)
            if prob > 0
        )
        confidence_score = 1 + log_sum / math.log(num_classes)
        return confidence_score < self.threshold_cross_entropy


@ClassFactory.register(ClassType.HEM, alias="IBT")
class IBTFilter(BaseFilter, abc.ABC):
    """
    **Object detection** Hard samples discovery methods named `IBT`

    Parameters
    ----------
    threshold_img: float
        hard coefficient threshold score to filter img, default to 0.5;
        an image is hard when its hard coefficient is at least
        ``1 - threshold_img``.
    threshold_box: float
        threshold_box to calculate hard coefficient, formula is
        hard coefficient = number(prediction_boxes with score
        <= threshold_box) / number(prediction_boxes)
    """

    def __init__(self, threshold_img=0.5, threshold_box=0.5, **kwargs):
        self.threshold_box = float(threshold_box)
        self.threshold_img = float(threshold_img)

    def __call__(self, infer_result=None) -> bool:
        """Judge whether the image is a hard sample.

        Parameters
        ----------
        infer_result: array_like
            prediction boxes list, such as [bbox1, bbox2, bbox3, ...], where
            bbox = [xmin, ymin, xmax, ymax, score, label]. Each score must lie
            in [0, 1].

        Returns
        -------
        is hard sample: bool
            `True` means hard sample, `False` means not. Invalid input
            yields `False`. Invalid input means ``None`` or an unsized
            object, an empty list, a bbox with fewer than five elements,
            or any score outside [0, 1].
        """
        # Use len() rather than truthiness: a numpy array (a valid
        # array_like input) raises ValueError when evaluated as a bool.
        try:
            num_boxes = len(infer_result)
        except TypeError:
            # None or an unsized object: invalid input.
            return False
        if num_boxes == 0 or not all(len(bbox) > 4 for bbox in infer_result):
            return False
        scores = [bbox[4] for bbox in infer_result]
        # A single out-of-range score invalidates the whole result.
        if not all(self.data_check(score) for score in scores):
            return False
        num_low_boxes = sum(
            1 for score in scores if float(score) <= self.threshold_box)
        hard_coefficient = num_low_boxes / num_boxes
        return hard_coefficient >= (1 - self.threshold_img)


@ClassFactory.register(ClassType.HEM, alias="Random")
class RandomFilter(BaseFilter):
    """Judge whether an image is a hard example at random.

    Useful with a model that already has very high accuracy (e.g. 98%),
    where meaningful hard example mining is not possible but sedna
    incremental learning inference still requires a hard example mining
    algorithm.

    Parameters
    ----------
    random_ratio: float
        value between 0 and 1: the probability that a call reports a
        hard example, default to 0.3.
    """

    def __init__(self, random_ratio=0.3, **kwargs):
        # Coerce with float() like the other filters' parameters, so a
        # numeric string such as "0.3" compares numerically instead of
        # raising TypeError in __call__.
        self.random_ratio = float(random_ratio)

    def __call__(self, *args, **kwargs):
        """Return `True` (hard sample) with probability ``random_ratio``.

        All arguments are ignored.

        Returns
        -------
        is hard sample: bool
            `True` means hard sample, `False` means not.
        """
        return random.uniform(0, 1) < self.random_ratio