Installing and Using Focal Loss with TensorFlow 2.x
Posted by 一颗小树x
Preface
This post shows a convenient way to use the focal loss function in TensorFlow 2.x: the focal-loss package can be installed with pip and is straightforward to call.
1. Installation
Option 1: install directly with pip
pip install focal-loss
Current version: focal-loss 0.0.7
Supported Python versions: 3.6, 3.7, and 3.8
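A quick way to confirm the install (a small check of my own, not from the original post):
import focal_loss
print(focal_loss.__version__)  # should print the installed version, e.g. 0.0.7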
Option 2: install from source
Source repository: https://github.com/artemmavrin/focal-loss
I picked the latest release at the time, 0.0.7, and downloaded the source archive: focal-loss-master.zip
Before installing, let's look at the setup script, setup.py, to see which libraries it depends on:
"""Setup script."""
import os
import pathlib
import re
from setuptools import setup, find_packages
# Set the environment variable TF_CPU (to anything) to use tensorflow-cpu
_TENSORFLOW_CPU = os.environ.get('TF_CPU', None)
# TensorFlow package name and version
_TENSORFLOW = 'tensorflow' if _TENSORFLOW_CPU is None else 'tensorflow-cpu'
_MIN_TENSORFLOW_VERSION = '2.2'
_TENSORFLOW += f'>=_MIN_TENSORFLOW_VERSION'
# Directory of this setup.py file
_HERE = pathlib.Path(__file__).parent
def _resolve_path(*parts):
"""Get a filename from a list of path components, relative to this file."""
return _HERE.joinpath(*parts).absolute()
def _read(*parts):
"""Read a file's contents into a string."""
filename = _resolve_path(*parts)
return filename.read_text()
__INIT__ = _read('src', 'focal_loss', '__init__.py')
def _get_package_variable(name):
pattern = rf'^name = [\\'"](?P<value>[^\\'"]*)[\\'"]'
match = re.search(pattern, __INIT__, flags=re.M)
if match:
return match.group('value')
raise RuntimeError(f'Cannot find variable name')
setup(
name=_get_package_variable('__package__'),
version=_get_package_variable('__version__'),
description=_get_package_variable('__description__'),
url=_get_package_variable('__url__'),
author=_get_package_variable('__author__'),
author_email=_get_package_variable('__author_email__'),
long_description=_read('README.rst'),
long_description_content_type='text/x-rst',
packages=find_packages('src', exclude=['*.tests']),
package_dir='': 'src',
license='Apache 2.0',
classifiers=[
'Development Status :: 1 - Planning',
'Intended Audience :: Developers',
'Intended Audience :: Education',
'Intended Audience :: Science/Research',
'License :: OSI Approved :: Apache Software License',
'Operating System :: OS Independent',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
'Topic :: Scientific/Engineering',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'Topic :: Scientific/Engineering :: Mathematics',
'Topic :: Software Development',
'Topic :: Software Development :: Libraries',
'Topic :: Software Development :: Libraries :: Python Modules',
],
install_requires=[
_TENSORFLOW,
],
extras_require=
# The 'dev' extra is for development, including running tests and
# generating documentation
'dev': [
'numpy',
'scipy',
'matplotlib',
'seaborn',
'pytest',
'coverage',
'sphinx',
'sphinx_rtd_theme',
],
,
zip_safe=False,
)
From this we can see that it requires TensorFlow 2.2 or later and supports Python 3.6, 3.7, and 3.8.
The 'dev' extra depends on these libraries:
'numpy',
'scipy',
'matplotlib',
'seaborn',
'pytest',
'coverage',
'sphinx',
'sphinx_rtd_theme',
If you already have TensorFlow 2.2 or later installed, consider commenting out the tensorflow, numpy, and scipy entries in setup.py; that way, installing focal-loss will skip checking whether those libraries are present and whether their versions match.
Then activate your TensorFlow development environment, change into the focal-loss-master directory, and run the following command to install from source:
python setup.py install
Wait for the installation to finish.
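One detail visible in setup.py above: if the environment variable TF_CPU is set (to any value), the declared dependency becomes tensorflow-cpu instead of tensorflow. So, for example, a source install against the CPU-only package could look like this (Linux/macOS shell syntax):
TF_CPU=1 python setup.py install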
2. Using Focal Loss
Usage notes:
- Binary focal loss: BinaryFocalLoss (use like tf.keras.losses.BinaryCrossentropy)
- Multiclass focal loss: SparseCategoricalFocalLoss (use like tf.keras.losses.SparseCategoricalCrossentropy)
Example 1: binary focal loss
# Typical tf.keras API usage
import tensorflow as tf
from focal_loss import BinaryFocalLoss
model = tf.keras.Model(...)
model.compile(
optimizer=...,
loss=BinaryFocalLoss(gamma=2), # Used here like a tf.keras loss
metrics=...,
)
history = model.fit(...)
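The snippet above uses placeholders (...); for completeness, here is a self-contained sketch of my own, with toy data, a toy model, and illustrative hyperparameters (none of this is from the original post), showing BinaryFocalLoss in a full compile/fit cycle:
# Self-contained sketch: toy data, toy model, illustrative hyperparameters
import numpy as np
import tensorflow as tf
from focal_loss import BinaryFocalLoss

# Imbalanced toy data: roughly 10% positive examples
x = np.random.rand(1000, 20).astype('float32')
y = (np.random.rand(1000) < 0.1).astype('float32')

model = tf.keras.Sequential([
    tf.keras.Input(shape=(20,)),
    tf.keras.layers.Dense(32, activation='relu'),
    tf.keras.layers.Dense(1, activation='sigmoid'),   # probabilities, so from_logits=False (the default)
])
model.compile(
    optimizer='adam',
    loss=BinaryFocalLoss(gamma=2, pos_weight=2.0),     # pos_weight up-weights the rare positive class
    metrics=['accuracy'],
)
history = model.fit(x, y, batch_size=32, epochs=3, verbose=0)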
Example 2: multiclass focal loss
# Typical tf.keras API usage
import tensorflow as tf
from focal_loss import SparseCategoricalFocalLoss
model = tf.keras.Model(...)
model.compile(
optimizer=...,
loss=SparseCategoricalFocalLoss(gamma=2), # Used here like a tf.keras loss
metrics=...,
)
history = model.fit(...)
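Again, a self-contained sketch of my own (toy data and model, illustrative hyperparameters) with SparseCategoricalFocalLoss, this time feeding logits to the loss:
# Self-contained sketch: toy data, toy model, illustrative hyperparameters
import numpy as np
import tensorflow as tf
from focal_loss import SparseCategoricalFocalLoss

x = np.random.rand(1000, 20).astype('float32')
y = np.random.randint(0, 3, size=(1000,))             # integer labels in {0, 1, 2}

model = tf.keras.Sequential([
    tf.keras.Input(shape=(20,)),
    tf.keras.layers.Dense(32, activation='relu'),
    tf.keras.layers.Dense(3),                          # no softmax: the model outputs logits
])
model.compile(
    optimizer='adam',
    loss=SparseCategoricalFocalLoss(gamma=2, from_logits=True),
    metrics=['sparse_categorical_accuracy'],
)
history = model.fit(x, y, batch_size=32, epochs=3, verbose=0)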
Their constructor parameters, in brief: BinaryFocalLoss takes gamma, pos_weight, from_logits, and label_smoothing; SparseCategoricalFocalLoss takes gamma, class_weight, and from_logits. All of them are documented in the source code in the next section.
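A small sketch spelling these out (the argument values are illustrative only, my own choices):
from focal_loss import BinaryFocalLoss, SparseCategoricalFocalLoss

binary_loss = BinaryFocalLoss(
    gamma=2,               # focusing parameter; gamma=0 recovers ordinary cross-entropy
    pos_weight=0.5,        # alpha: weight applied to the positive-class term
    from_logits=False,     # whether y_pred holds logits instead of probabilities
    label_smoothing=0.05,  # squeeze the 0/1 labels toward 0.5
)

multiclass_loss = SparseCategoricalFocalLoss(
    gamma=[2, 2, 3],                 # a scalar, or one focusing parameter per class
    class_weight=[1.0, 1.0, 2.0],    # optional weighting factor per class
    from_logits=True,
)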
3. A Look at the Source Code
The binary version:
"""Binary focal loss implementation."""
# (ASCII-art "focal loss" banner omitted)
from functools import partial
import tensorflow as tf
from .utils.validation import check_bool, check_float
_EPSILON = tf.keras.backend.epsilon()
def binary_focal_loss(y_true, y_pred, gamma, *, pos_weight=None,
from_logits=False, label_smoothing=None):
r"""Focal loss function for binary classification.
This loss function generalizes binary cross-entropy by introducing a
hyperparameter :math:`\gamma` (gamma), called the *focusing parameter*,
that allows hard-to-classify examples to be penalized more heavily relative
to easy-to-classify examples.
The focal loss [1]_ is defined as
.. math::
    L(y, \hat{p})
    = -\alpha y \left(1 - \hat{p}\right)^\gamma \log(\hat{p})
    - (1 - y) \hat{p}^\gamma \log(1 - \hat{p})
where
* :math:`y \in \{0, 1\}` is a binary class label,
* :math:`\hat{p} \in [0, 1]` is an estimate of the probability of the
positive class,
* :math:`\gamma` is the *focusing parameter* that specifies how much
higher-confidence correct predictions contribute to the overall loss
(the higher the :math:`\gamma`, the higher the rate at which
easy-to-classify examples are down-weighted).
* :math:`\alpha` is a hyperparameter that governs the trade-off between
precision and recall by weighting errors for the positive class up or
down (:math:`\alpha=1` is the default, which is the same as no
weighting),
The usual weighted binary cross-entropy loss is recovered by setting
:math:`\gamma = 0`.
Parameters
----------
y_true : tensor-like
Binary (0 or 1) class labels.
y_pred : tensor-like
Either probabilities for the positive class or logits for the positive
class, depending on the `from_logits` parameter. The shapes of `y_true`
and `y_pred` should be broadcastable.
gamma : float
The focusing parameter :math:`\gamma`. Higher values of `gamma` make
easy-to-classify examples contribute less to the loss relative to
hard-to-classify examples. Must be non-negative.
pos_weight : float, optional
The coefficient :math:`\alpha` to use on the positive examples. Must be
non-negative.
from_logits : bool, optional
Whether `y_pred` contains logits or probabilities.
label_smoothing : float, optional
Float in [0, 1]. When 0, no smoothing occurs. When positive, the binary
ground truth labels `y_true` are squeezed toward 0.5, with larger values
of `label_smoothing` leading to label values closer to 0.5.
Returns
-------
:class:`tf.Tensor`
The focal loss for each example (assuming `y_true` and `y_pred` have the
same shapes). In general, the shape of the output is the result of
broadcasting the shapes of `y_true` and `y_pred`.
Warnings
--------
This function does not reduce its output to a scalar, so it cannot be passed
to :meth:`tf.keras.Model.compile` as a `loss` argument. Instead, use the
wrapper class :class:`~focal_loss.BinaryFocalLoss`.
Examples
--------
This function computes the per-example focal loss between a label and
prediction tensor:
>>> import numpy as np
>>> from focal_loss import binary_focal_loss
>>> loss = binary_focal_loss([0, 1, 1], [0.1, 0.7, 0.9], gamma=2)
>>> np.set_printoptions(precision=3)
>>> print(loss.numpy())
[0.001 0.032 0.001]
Below is a visualization of the focal loss between the positive class and
predicted probabilities between 0 and 1. Note that as :math:`\gamma`
increases, the losses for predictions closer to 1 get smoothly pushed to 0.
.. plot::
    :include-source:
    :align: center

    import numpy as np
    import matplotlib.pyplot as plt
    from focal_loss import binary_focal_loss

    ps = np.linspace(0, 1, 100)
    gammas = (0, 0.5, 1, 2, 5)
    plt.figure()
    for gamma in gammas:
        loss = binary_focal_loss(1, ps, gamma=gamma)
        label = rf'$\gamma$={gamma}'
        if gamma == 0:
            label += ' (cross-entropy)'
        plt.plot(ps, loss, label=label)
    plt.legend(loc='best', frameon=True, shadow=True)
    plt.xlim(0, 1)
    plt.ylim(0, 4)
    plt.xlabel(r'Probability of positive class $\hat{p}$')
    plt.ylabel('Loss')
    plt.title(r'Plot of focal loss $L(1, \hat{p})$ for different $\gamma$',
              fontsize=14)
    plt.show()
Notes
-----
A classifier often estimates the positive class probability :math:`\hat{p}`
by computing a real-valued *logit* :math:`\hat{y} \in \mathbb{R}` and
applying the *sigmoid function* :math:`\sigma : \mathbb{R} \to (0, 1)`
defined by
.. math::
    \sigma(t) = \frac{1}{1 + e^{-t}}, \qquad (t \in \mathbb{R}).
That is, :math:`\hat{p} = \sigma(\hat{y})`. In this case, the focal loss
can be written as a function of the logit :math:`\hat{y}` instead of the
predicted probability :math:`\hat{p}`:
.. math::
    L(y, \hat{y})
    = -\alpha y \left(1 - \sigma(\hat{y})\right)^\gamma
    \log(\sigma(\hat{y}))
    - (1 - y) \sigma(\hat{y})^\gamma \log(1 - \sigma(\hat{y})).
This is the formula that is computed when specifying `from_logits=True`.
However, this formula is not very numerically stable if implemented
directly; for example, there are multiple log and sigmoid computations
involved. Instead, we use some tricks to rewrite it in the more numerically
stable form
.. math::
    L(y, \hat{y})
    = (1 - y) \hat{p}^\gamma \hat{y}
    + \left(\alpha y \hat{q}^\gamma + (1 - y) \hat{p}^\gamma\right)
    \left(\log(1 + e^{-|\hat{y}|}) + \max\{-\hat{y}, 0\}\right),
where :math:`\hat{p} = \sigma(\hat{y})` and :math:`\hat{q} = 1 - \hat{p}`
denote the estimates of the probabilities of the positive and negative
classes, respectively.
Indeed, starting with the observations that
.. math::
    \log(\sigma(\hat{y}))
    = \log\left(\frac{1}{1 + e^{-\hat{y}}}\right)
    = -\log(1 + e^{-\hat{y}})
and
.. math::
    \log(1 - \sigma(\hat{y}))
    = \log\left(\frac{e^{-\hat{y}}}{1 + e^{-\hat{y}}}\right)
    = -\hat{y} - \log(1 + e^{-\hat{y}}),
we obtain
.. math::
    \begin{aligned}
    L(y, \hat{y})
    &= -\alpha y \hat{q}^\gamma \log(\sigma(\hat{y}))
    - (1 - y) \hat{p}^\gamma \log(1 - \sigma(\hat{y})) \\
    &= \alpha y \hat{q}^\gamma \log(1 + e^{-\hat{y}})
    + (1 - y) \hat{p}^\gamma \left(\hat{y} + \log(1 + e^{-\hat{y}})\right)\\
    &= (1 - y) \hat{p}^\gamma \hat{y}
    + \left(\alpha y \hat{q}^\gamma + (1 - y) \hat{p}^\gamma\right)
    \log(1 + e^{-\hat{y}}).
    \end{aligned}
Note that if :math:`\hat{y} < 0`, then the exponential term
:math:`e^{-\hat{y}}` could become very large. In this case, we can instead
observe that
.. math::
    \begin{align*}
    \log(1 + e^{-\hat{y}})
    &= \log(1 + e^{-\hat{y}}) + \hat{y} - \hat{y} \\
    &= \log(1 + e^{-\hat{y}}) + \log(e^{\hat{y}}) - \hat{y} \\
    &= \log(1 + e^{\hat{y}}) - \hat{y}.
    \end{align*}
Moreover, the :math:`\hat{y} < 0` and :math:`\hat{y} \geq 0` cases can be
unified by writing
.. math::
    \log(1 + e^{-\hat{y}})
    = \log(1 + e^{-|\hat{y}|}) + \max\{-\hat{y}, 0\}.
Thus, we arrive at the numerically stable formula shown earlier.
References
----------
.. [1] T. Lin, P. Goyal, R. Girshick, K. He and P. Dollár. Focal loss for
dense object detection. IEEE Transactions on Pattern Analysis and
Machine Intelligence, 2018.
(`DOI <https://doi.org/10.1109/TPAMI.2018.2858826>`__)
(`arXiv preprint <https://arxiv.org/abs/1708.02002>`__)
See Also
--------
:meth:`~focal_loss.BinaryFocalLoss`
A wrapper around this function that makes it a
:class:`tf.keras.losses.Loss`.
"""
# Validate arguments
gamma = check_float(gamma, name='gamma', minimum=0)
pos_weight = check_float(pos_weight, name='pos_weight', minimum=0,
allow_none=True)
from_logits = check_bool(from_logits, name='from_logits')
label_smoothing = check_float(label_smoothing, name='label_smoothing',
minimum=0, maximum=1, allow_none=True)
# Ensure predictions are a floating point tensor; converting labels to a
# tensor will be done in the helper functions
y_pred = tf.convert_to_tensor(y_pred)
if not y_pred.dtype.is_floating:
y_pred = tf.dtypes.cast(y_pred, dtype=tf.float32)
# Delegate per-example loss computation to helpers depending on whether
# predictions are logits or probabilities
if from_logits:
return _binary_focal_loss_from_logits(labels=y_true, logits=y_pred,
gamma=gamma,
pos_weight=pos_weight,
label_smoothing=label_smoothing)
else:
return _binary_focal_loss_from_probs(labels=y_true, p=y_pred,
gamma=gamma, pos_weight=pos_weight,
label_smoothing=label_smoothing)
@tf.keras.utils.register_keras_serializable()
class BinaryFocalLoss(tf.keras.losses.Loss):
r"""Focal loss function for binary classification.
This loss function generalizes binary cross-entropy by introducing a
hyperparameter called the *focusing parameter* that allows hard-to-classify
examples to be penalized more heavily relative to easy-to-classify examples.
This class is a wrapper around :class:`~focal_loss.binary_focal_loss`. See
the documentation there for details about this loss function.
Parameters
----------
gamma : float
The focusing parameter :math:`\gamma`. Must be non-negative.
pos_weight : float, optional
The coefficient :math:`\alpha` to use on the positive examples. Must be
non-negative.
from_logits : bool, optional
Whether model prediction will be logits or probabilities.
label_smoothing : float, optional
Float in [0, 1]. When 0, no smoothing occurs. When positive, the binary
ground truth labels are squeezed toward 0.5, with larger values of
`label_smoothing` leading to label values closer to 0.5.
**kwargs : keyword arguments
Other keyword arguments for :class:`tf.keras.losses.Loss` (e.g., `name`
or `reduction`).
Examples
--------
An instance of this class is a callable that takes a tensor of binary ground
truth labels `y_true` and a tensor of model predictions `y_pred` and returns
a scalar tensor obtained by reducing the per-example focal loss (the default
reduction is a batch-wise average).
>>> from focal_loss import BinaryFocalLoss
>>> loss_func = BinaryFocalLoss(gamma=2)
>>> loss = loss_func([0, 1, 1], [0.1, 0.7, 0.9]) # A scalar tensor
>>> print(f'Mean focal loss: {loss.numpy():.3f}')
Mean focal loss: 0.011
Use this class in the :mod:`tf.keras` API like any other binary
classification loss function class found in :mod:`tf.keras.losses` (e.g.,
:class:`tf.keras.losses.BinaryCrossentropy`):
.. code-block:: python
# Typical usage
model = tf.keras.Model(...)
model.compile(
optimizer=...,
loss=BinaryFocalLoss(gamma=2), # Used here like a tf.keras loss
metrics=...,
)
history = model.fit(...)
See Also
--------
:meth:`~focal_loss.binary_focal_loss`
The function that performs the focal loss computation, taking a label
tensor and a prediction tensor and outputting a loss.
"""
def __init__(self, gamma, *, pos_weight=None, from_logits=False,
label_smoothing=None, **kwargs):
# Validate arguments
gamma = check_float(gamma, name='gamma', minimum=0)
pos_weight = check_float(pos_weight, name='pos_weight', minimum=0,
allow_none=True)
from_logits = check_bool(from_logits, name='from_logits')
label_smoothing = check_float(label_smoothing, name='label_smoothing',
minimum=0, maximum=1, allow_none=True)
super().__init__(**kwargs)
self.gamma = gamma
self.pos_weight = pos_weight
self.from_logits = from_logits
self.label_smoothing = label_smoothing
def get_config(self):
"""Returns the config of the layer.
A layer config is a Python dictionary containing the configuration of a
layer. The same layer can be re-instantiated later (without its trained
weights) from this configuration.
Returns
-------
dict
This layer's config.
"""
config = super().get_config()
config.update(gamma=self.gamma, pos_weight=self.pos_weight,
from_logits=self.from_logits,
label_smoothing=self.label_smoothing)
return config
def call(self, y_true, y_pred):
"""Compute the per-example focal loss.
This method simply calls :meth:`~focal_loss.binary_focal_loss` with the
appropriate arguments.
Parameters
----------
y_true : tensor-like
Binary (0 or 1) class labels.
y_pred : tensor-like
Either probabilities for the positive class or logits for the
positive class, depending on the `from_logits` attribute. The shapes
of `y_true` and `y_pred` should be broadcastable.
Returns
-------
:class:`tf.Tensor`
The per-example focal loss. Reduction to a scalar is handled by
this layer's :meth:`~focal_loss.BinaryFocalLoss.__call__` method.
"""
return binary_focal_loss(y_true=y_true, y_pred=y_pred, gamma=self.gamma,
pos_weight=self.pos_weight,
from_logits=self.from_logits,
label_smoothing=self.label_smoothing)
# Helper functions below
def _process_labels(labels, label_smoothing, dtype):
"""Pre-process a binary label tensor, maybe applying smoothing.
Parameters
----------
labels : tensor-like
Tensor of 0's and 1's.
label_smoothing : float or None
Float in [0, 1]. When 0, no smoothing occurs. When positive, the binary
ground truth labels `y_true` are squeezed toward 0.5, with larger values
of `label_smoothing` leading to label values closer to 0.5.
dtype : tf.dtypes.DType
Desired type of the elements of `labels`.
Returns
-------
tf.Tensor
The processed labels.
"""
labels = tf.dtypes.cast(labels, dtype=dtype)
if label_smoothing is not None:
labels = (1 - label_smoothing) * labels + label_smoothing * 0.5
return labels
def _binary_focal_loss_from_logits(labels, logits, gamma, pos_weight,
label_smoothing):
"""Compute focal loss from logits using a numerically stable formula.
Parameters
----------
labels : tensor-like
Tensor of 0's and 1's: binary class labels.
logits : tf.Tensor
Logits for the positive class.
gamma : float
Focusing parameter.
pos_weight : float or None
If not None, losses for the positive class will be scaled by this
weight.
label_smoothing : float or None
Float in [0, 1]. When 0, no smoothing occurs. When positive, the binary
ground truth labels `y_true` are squeezed toward 0.5, with larger values
of `label_smoothing` leading to label values closer to 0.5.
Returns
-------
tf.Tensor
The loss for each example.
"""
labels = _process_labels(labels=labels, label_smoothing=label_smoothing,
dtype=logits.dtype)
# Compute probabilities for the positive class
p = tf.math.sigmoid(logits)
# Without label smoothing we can use TensorFlow's built-in per-example cross
# entropy loss functions and multiply the result by the modulating factor.
# Otherwise, we compute the focal loss ourselves using a numerically stable
# formula below
if label_smoothing is None:
# The labels and logits tensors' shapes need to be the same for the
# built-in cross-entropy functions. Since we want to allow broadcasting,
# we do some checks on the shapes and possibly broadcast explicitly
# Note: tensor.shape returns a tf.TensorShape, whereas tf.shape(tensor)
# returns an int tf.Tensor; this is why both are used below
labels_shape = labels.shape
logits_shape = logits.shape
if not labels_shape.is_fully_defined() or labels_shape != logits_shape:
labels_shape = tf.shape(labels)
logits_shape = tf.shape(logits)
shape = tf.broadcast_dynamic_shape(labels_shape, logits_shape)
labels = tf.broadcast_to(labels, shape)
logits = tf.broadcast_to(logits, shape)
if pos_weight is None:
loss_func = tf.nn.sigmoid_cross_entropy_with_logits
else:
loss_func = partial(tf.nn.weighted_cross_entropy_with_logits,
pos_weight=pos_weight)
loss = loss_func(labels=labels, logits=logits)
modulation_pos = (1 - p) ** gamma
modulation_neg = p ** gamma
mask = tf.dtypes.cast(labels, dtype=tf.bool)
modulation = tf.where(mask, modulation_pos, modulation_neg)
return modulation * loss
# Terms for the positive and negative class components of the loss
pos_term = labels * ((1 - p) ** gamma)
neg_term = (1 - labels) * (p ** gamma)
# Term involving the log and ReLU
log_weight = pos_term
if pos_weight is not None:
log_weight *= pos_weight
log_weight += neg_term
log_term = tf.math.log1p(tf.math.exp(-tf.math.abs(logits)))
log_term += tf.nn.relu(-logits)
log_term *= log_weight
# Combine all the terms into the loss
loss = neg_term * logits + log_term
return loss
def _binary_focal_loss_from_probs(labels, p, gamma, pos_weight,
label_smoothing):
"""Compute focal loss from probabilities.
Parameters
----------
labels : tensor-like
Tensor of 0's and 1's: binary class labels.
p : tf.Tensor
Estimated probabilities for the positive class.
gamma : float
Focusing parameter.
pos_weight : float or None
If not None, losses for the positive class will be scaled by this
weight.
label_smoothing : float or None
Float in [0, 1]. When 0, no smoothing occurs. When positive, the binary
ground truth labels `y_true` are squeezed toward 0.5, with larger values
of `label_smoothing` leading to label values closer to 0.5.
Returns
-------
tf.Tensor
The loss for each example.
"""
# Predicted probabilities for the negative class
q = 1 - p
# For numerical stability (so we don't inadvertently take the log of 0)
p = tf.math.maximum(p, _EPSILON)
q = tf.math.maximum(q, _EPSILON)
# Loss for the positive examples
pos_loss = -(q ** gamma) * tf.math.log(p)
if pos_weight is not None:
pos_loss *= pos_weight
# Loss for the negative examples
neg_loss = -(p ** gamma) * tf.math.log(q)
# Combine loss terms
if label_smoothing is None:
labels = tf.dtypes.cast(labels, dtype=tf.bool)
loss = tf.where(labels, pos_loss, neg_loss)
else:
labels = _process_labels(labels=labels, label_smoothing=label_smoothing,
dtype=p.dtype)
loss = labels * pos_loss + (1 - labels) * neg_loss
return loss
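The docstring above notes that setting gamma = 0 recovers ordinary (weighted) binary cross-entropy. A quick check of that claim, as a sketch of my own (not taken from the library's test suite):
import numpy as np
import tensorflow as tf
from focal_loss import binary_focal_loss

y_true = np.array([0., 1., 1., 0.], dtype='float32')
y_pred = np.array([0.1, 0.7, 0.9, 0.4], dtype='float32')

# Per-example focal loss with gamma=0 ...
fl = binary_focal_loss(y_true, y_pred, gamma=0)
# ... versus per-example binary cross-entropy (extra axis so Keras does not average over examples)
bce = tf.keras.losses.binary_crossentropy(y_true[:, None], y_pred[:, None])

print(np.allclose(fl.numpy(), bce.numpy(), atol=1e-5))  # expected: True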
The multiclass version:
"""Multiclass focal loss implementation."""
# (ASCII-art "focal loss" banner omitted)
import itertools
from typing import Any, Optional
import tensorflow as tf
_EPSILON = tf.keras.backend.epsilon()
def sparse_categorical_focal_loss(y_true, y_pred, gamma, *,
class_weight: Optional[Any] = None,
from_logits: bool = False, axis: int = -1
) -> tf.Tensor:
r"""Focal loss function for multiclass classification with integer labels.
This loss function generalizes multiclass softmax cross-entropy by
introducing a hyperparameter called the *focusing parameter* that allows
hard-to-classify examples to be penalized more heavily relative to
easy-to-classify examples.
See :meth:`~focal_loss.binary_focal_loss` for a description of the focal
loss in the binary setting, as presented in the original work [1]_.
In the multiclass setting, with integer labels :math:`y`, focal loss is
defined as
.. math::
    L(y, \hat{\mathbf{p}})
    = -\left(1 - \hat{p}_y\right)^\gamma \log(\hat{p}_y)
where
* :math:`y \in \{0, \ldots, K - 1\}` is an integer class label (:math:`K`
denotes the number of classes),
* :math:`\hat{\mathbf{p}} = (\hat{p}_0, \ldots, \hat{p}_{K-1})
\in [0, 1]^K` is a vector representing an estimated probability
distribution over the :math:`K` classes,
* :math:`\gamma` (gamma, not :math:`y`) is the *focusing parameter* that
specifies how much higher-confidence correct predictions contribute to
the overall loss (the higher the :math:`\gamma`, the higher the rate at
which easy-to-classify examples are down-weighted).
The usual multiclass softmax cross-entropy loss is recovered by setting
:math:`\gamma = 0`.
Parameters
----------
y_true : tensor-like
Integer class labels.
y_pred : tensor-like
Either probabilities or logits, depending on the `from_logits`
parameter.
gamma : float or tensor-like of shape (K,)
The focusing parameter :math:`\gamma`. Higher values of `gamma` make
easy-to-classify examples contribute less to the loss relative to
hard-to-classify examples. Must be non-negative. This can be a
one-dimensional tensor, in which case it specifies a focusing parameter
for each class.
class_weight: tensor-like of shape (K,)
Weighting factor for each of the :math:`K` classes. If not specified,
then all classes are weighted equally.
from_logits : bool, optional
Whether `y_pred` contains logits or probabilities.
axis : int, optional
Channel axis in the `y_pred` tensor.
Returns
-------
:class:`tf.Tensor`
The focal loss for each example.
Examples
--------
This function computes the per-example focal loss between a one-dimensional
integer label vector and a two-dimensional prediction matrix:
>>> import numpy as np
>>> from focal_loss import sparse_categorical_focal_loss
>>> y_true = [0, 1, 2]
>>> y_pred = [[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.2, 0.2, 0.6]]
>>> loss = sparse_categorical_focal_loss(y_true, y_pred, gamma=2)
>>> np.set_printoptions(precision=3)
>>> print(loss.numpy())
[0.009 0.032 0.082]
Warnings
--------
This function does not reduce its output to a scalar, so it cannot be passed
to :meth:`tf.keras.Model.compile` as a `loss` argument. Instead, use the
wrapper class :class:`~focal_loss.SparseCategoricalFocalLoss`.
References
----------
.. [1] T. Lin, P. Goyal, R. Girshick, K. He and P. Dollár. Focal loss for
dense object detection. IEEE Transactions on Pattern Analysis and
Machine Intelligence, 2018.
(`DOI <https://doi.org/10.1109/TPAMI.2018.2858826>`__)
(`arXiv preprint <https://arxiv.org/abs/1708.02002>`__)
See Also
--------
:meth:`~focal_loss.SparseCategoricalFocalLoss`
A wrapper around this function that makes it a
:class:`tf.keras.losses.Loss`.
"""
# Process focusing parameter
gamma = tf.convert_to_tensor(gamma, dtype=tf.dtypes.float32)
gamma_rank = gamma.shape.rank
scalar_gamma = gamma_rank == 0
# Process class weight
if class_weight is not None:
class_weight = tf.convert_to_tensor(class_weight,
dtype=tf.dtypes.float32)
# Process prediction tensor
y_pred = tf.convert_to_tensor(y_pred)
y_pred_rank = y_pred.shape.rank
if y_pred_rank is not None:
axis %= y_pred_rank
if axis != y_pred_rank - 1:
# Put channel axis last for sparse_softmax_cross_entropy_with_logits
perm = list(itertools.chain(range(axis),
range(axis + 1, y_pred_rank), [axis]))
y_pred = tf.transpose(y_pred, perm=perm)
elif axis != -1:
raise ValueError(
f'Cannot compute sparse categorical focal loss with axis={axis} on '
'a prediction tensor with statically unknown rank.')
y_pred_shape = tf.shape(y_pred)
# Process ground truth tensor
y_true = tf.dtypes.cast(y_true, dtype=tf.dtypes.int64)
y_true_rank = y_true.shape.rank
if y_true_rank is None:
raise NotImplementedError('Sparse categorical focal loss not supported '
'for target/label tensors of unknown rank')
reshape_needed = (y_true_rank is not None and y_pred_rank is not None and
y_pred_rank != y_true_rank + 1)
if reshape_needed:
y_true = tf.reshape(y_true, [-1])
y_pred = tf.reshape(y_pred, [-1, y_pred_shape[-1]])
if from_logits:
logits = y_pred
probs = tf.nn.softmax(y_pred, axis=-1)
else:
probs = y_pred
logits = tf.math.log(tf.clip_by_value(y_pred, _EPSILON, 1 - _EPSILON))
xent_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=y_true,
logits=logits,
)
y_true_rank = y_true.shape.rank
probs = tf.gather(probs, y_true, axis=-1, batch_dims=y_true_rank)
if not scalar_gamma:
gamma = tf.gather(gamma, y_true, axis=0, batch_dims=y_true_rank)
focal_modulation = (1 - probs) ** gamma
loss = focal_modulation * xent_loss
if class_weight is not None:
class_weight = tf.gather(class_weight, y_true, axis=0,
batch_dims=y_true_rank)
loss *= class_weight
if reshape_needed:
loss = tf.reshape(loss, y_pred_shape[:-1])
return loss
@tf.keras.utils.register_keras_serializable()
class SparseCategoricalFocalLoss(tf.keras.losses.Loss):
r"""Focal loss function for multiclass classification with integer labels.
This loss function generalizes multiclass softmax cross-entropy by
introducing a hyperparameter :math:`\gamma` (gamma), called the
*focusing parameter*, that allows hard-to-classify examples to be penalized
more heavily relative to easy-to-classify examples.
This class is a wrapper around
:class:`~focal_loss.sparse_categorical_focal_loss`. See the documentation
there for details about this loss function.
Parameters
----------
gamma : float or tensor-like of shape (K,)
The focusing parameter :math:`\gamma`. Higher values of `gamma` make
easy-to-classify examples contribute less to the loss relative to
hard-to-classify examples. Must be non-negative. This can be a
one-dimensional tensor, in which case it specifies a focusing parameter
for each class.
class_weight: tensor-like of shape (K,)
Weighting factor for each of the :math:`K` classes. If not specified,
then all classes are weighted equally.
from_logits : bool, optional
Whether model prediction will be logits or probabilities.
**kwargs : keyword arguments
Other keyword arguments for :class:`tf.keras.losses.Loss` (e.g., `name`
or `reduction`).
Examples
--------
An instance of this class is a callable that takes a rank-one tensor of
integer class labels `y_true` and a tensor of model predictions `y_pred` and
returns a scalar tensor obtained by reducing the per-example focal loss (the
default reduction is a batch-wise average).
>>> from focal_loss import SparseCategoricalFocalLoss
>>> loss_func = SparseCategoricalFocalLoss(gamma=2)
>>> y_true = [0, 1, 2]
>>> y_pred = [[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.2, 0.2, 0.6]]
>>> loss_func(y_true, y_pred)
<tf.Tensor: shape=(), dtype=float32, numpy=0.040919524>
Use this class in the :mod:`tf.keras` API like any other multiclass
classification loss function class that accepts integer labels found in
:mod:`tf.keras.losses` (e.g.,
:class:`tf.keras.losses.SparseCategoricalCrossentropy`):
.. code-block:: python
# Typical usage
model = tf.keras.Model(...)
model.compile(
optimizer=...,
loss=SparseCategoricalFocalLoss(gamma=2), # Used here like a tf.keras loss
metrics=...,
)
history = model.fit(...)
See Also
--------
:meth:`~focal_loss.sparse_categorical_focal_loss`
The function that performs the focal loss computation, taking a label
tensor and a prediction tensor and outputting a loss.
"""
def __init__(self, gamma, class_weight: Optional[Any] = None,
from_logits: bool = False, **kwargs):
super().__init__(**kwargs)
self.gamma = gamma
self.class_weight = class_weight
self.from_logits = from_logits
def get_config(self):
"""Returns the config of the layer.
A layer config is a Python dictionary containing the configuration of a
layer. The same layer can be re-instantiated later (without its trained
weights) from this configuration.
Returns
-------
dict
This layer's config.
"""
config = super().get_config()
config.update(gamma=self.gamma, class_weight=self.class_weight,
from_logits=self.from_logits)
return config
def call(self, y_true, y_pred):
"""Compute the per-example focal loss.
This method simply calls
:meth:`~focal_loss.sparse_categorical_focal_loss` with the appropriate
arguments.
Parameters
----------
y_true : tensor-like, shape (N,)
Integer class labels.
y_pred : tensor-like, shape (N, K)
Either probabilities or logits, depending on the `from_logits`
parameter.
Returns
-------
:class:`tf.Tensor`
The per-example focal loss. Reduction to a scalar is handled by
this layer's
:meth:`~focal_loss.SparseCategoricalFocalLoss.__call__` method.
"""
return sparse_categorical_focal_loss(y_true=y_true, y_pred=y_pred,
class_weight=self.class_weight,
gamma=self.gamma,
from_logits=self.from_logits)
Source file: focal-loss/_categorical_focal_loss.py at master · artemmavrin/focal-loss · GitHub
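As with the binary version, the docstring states that gamma = 0 reduces the multiclass focal loss to the usual softmax cross-entropy. A quick check of that claim, reusing the docstring's example data (the comparison itself is my own sketch):
import numpy as np
import tensorflow as tf
from focal_loss import sparse_categorical_focal_loss

y_true = np.array([0, 1, 2])
y_pred = np.array([[0.8, 0.1, 0.1],
                   [0.2, 0.7, 0.1],
                   [0.2, 0.2, 0.6]], dtype='float32')

# Per-example focal loss with gamma=0 ...
fl = sparse_categorical_focal_loss(y_true, y_pred, gamma=0)
# ... versus per-example sparse categorical cross-entropy
ce = tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred)

print(np.allclose(fl.numpy(), ce.numpy(), atol=1e-5))  # expected: True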