Skip to content

Transformation

TopLabels

Bases: TransformationBase

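Selects the labels whose probability is strictly greater than min_threshold and returns them as an L2-normalized indicator vector (the original probability values are not kept).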

Source code in src/annotation/transformation.py
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
class TopLabels(TransformationBase):
    """
    Keep the labels whose probability exceeds ``min_threshold``.

    The output is an indicator vector (1 where ``distribution > min_threshold``,
    0 elsewhere) scaled to unit L2 norm. Only *which* entries cleared the
    threshold is retained; the original probability values are discarded.
    """

    def __init__(self, min_threshold: float = 0.05):
        # Strict lower bound: entries exactly equal to the threshold are dropped.
        self.min_threshold = min_threshold

    def transform(self, distribution: np.ndarray) -> np.ndarray:
        """
        Select entries above the threshold and L2-normalize the selection.

        Args:
            distribution: Array-like of per-label scores (presumably
                probabilities; this is not validated here). Lists and other
                array-likes are accepted and converted with ``np.asarray``.

        Returns:
            A float array with the same shape as ``distribution``. Each of
            the ``k`` selected entries equals ``1 / sqrt(k)``, and every other
            entry is 0. If no entry clears the threshold, an all-zero float
            array is returned.
        """
        distribution = np.asarray(distribution)
        # Cast to float so the return dtype is the same whether or not any
        # entry was selected (a bare comparison yields a bool array, and only
        # the division below would promote it to float).
        mask = (distribution > self.min_threshold).astype(float)

        # NOTE(review): L2 normalization means the result does not sum to 1.
        # If a probability distribution is intended downstream, L1
        # normalization (mask.sum()) may be what is wanted; confirm with
        # callers before changing.
        norm = np.linalg.norm(mask)
        if norm != 0:
            mask = mask / norm
        return mask