Config
CfgNode
这段代码定义了一个名为 CfgNode
的类,它继承自 fvcore.common.config.CfgNode
,并添加了一些新的功能和特性。以下是对这段代码的详细解释:
类定义
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
| class CfgNode(_CfgNode):
"""
The same as `fvcore.common.config.CfgNode`, but different in:
1. Use unsafe yaml loading by default.
Note that this may lead to arbitrary code execution: you must not
load a config file from untrusted sources before manually inspecting
the content of the file.
2. Support config versioning.
When attempting to merge an old config, it will convert the old config automatically.
.. automethod:: clone
.. automethod:: freeze
.. automethod:: defrost
.. automethod:: is_frozen
.. automethod:: load_yaml_with_base
.. automethod:: merge_from_list
.. automethod:: merge_from_other_cfg
"""
|
文档字符串
- 该类与
fvcore.common.config.CfgNode
类相同,但有两个主要不同点:
- 默认使用不安全的 YAML 加载,这可能导致任意代码执行,所以必须在加载配置文件之前手动检查文件内容。
- 支持配置版本控制。在尝试合并旧配置时,会自动转换旧配置。
- 该类还包含一些方法,如
clone
、freeze
、defrost
、is_frozen
、load_yaml_with_base
、merge_from_list
和 merge_from_other_cfg
。
类方法
1
2
3
| @classmethod
def _open_cfg(cls, filename):
return PathManager.open(filename, "r")
|
1
2
3
4
5
6
7
8
9
10
11
| def merge_from_file(self, cfg_filename: str, allow_unsafe: bool = True) -> None:
"""
Load content from the given config file and merge it into self.
Args:
cfg_filename: config filename
allow_unsafe: allow unsafe yaml syntax
"""
assert PathManager.isfile(cfg_filename), f"Config file '{cfg_filename}' does not exist!"
loaded_cfg = self.load_yaml_with_base(cfg_filename, allow_unsafe=allow_unsafe)
loaded_cfg = type(self)(loaded_cfg)
|
merge_from_file
方法用于从指定的配置文件加载内容并将其合并到当前对象中。
allow_unsafe
参数默认为 True
,允许不安全的 YAML 语法。
- 该方法首先检查配置文件是否存在,然后使用
load_yaml_with_base
方法加载配置文件内容。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
| from .defaults import _C
latest_ver = _C.VERSION
assert (
latest_ver == self.VERSION
), "CfgNode.merge_from_file is only allowed on a config object of latest version!"
logger = logging.getLogger(__name__)
loaded_ver = loaded_cfg.get("VERSION", None)
if loaded_ver is None:
from .compat import guess_version
loaded_ver = guess_version(loaded_cfg, cfg_filename)
assert loaded_ver <= self.VERSION, "Cannot merge a v{} config into a v{} config.".format(
loaded_ver, self.VERSION
)
|
- 该部分代码获取当前配置的最新版本并进行检查,确保配置对象是最新版本。
- 使用日志记录器记录相关信息。
- 获取加载配置的版本,如果未指定版本,则通过
guess_version
方法猜测版本。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
| if loaded_ver == self.VERSION:
self.merge_from_other_cfg(loaded_cfg)
else:
from .compat import upgrade_config, downgrade_config
logger.warning(
"Loading an old v{} config file '{}' by automatically upgrading to v{}. "
"See docs/CHANGELOG.md for instructions to update your files.".format(
loaded_ver, cfg_filename, self.VERSION
)
)
old_self = downgrade_config(self, to_version=loaded_ver)
old_self.merge_from_other_cfg(loaded_cfg)
new_config = upgrade_config(old_self)
self.clear()
self.update(new_config)
|
- 如果加载配置的版本与当前版本相同,则直接合并配置。
- 如果版本不同,则通过
upgrade_config
和 downgrade_config
方法处理版本兼容性,将旧配置升级到新版本并合并。
方法 dump
1
2
3
4
5
6
| def dump(self, *args, **kwargs):
"""
Returns:
str: a yaml string representation of the config
"""
return super().dump(*args, **kwargs)
|
- 该方法返回配置的 YAML 字符串表示形式,以便于显示在文档中。
总结
这个类主要用于处理配置文件的加载和合并,同时支持配置版本控制,确保在合并旧版本配置时进行必要的转换。默认使用不安全的 YAML 加载方法,所以需要注意配置文件的来源和安全性。
downgrade_config
这段代码定义了一个名为 downgrade_config
的函数,用于将配置从当前版本降级到指定的较旧版本。以下是对这段代码的详细解释:
函数定义
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
| def downgrade_config(cfg: CN, to_version: int) -> CN:
"""
Downgrade a config from its current version to an older version.
Args:
cfg (CfgNode):
to_version (int):
Note:
A general downgrade of arbitrary configs is not always possible due to the
different functionalities in different versions.
The purpose of downgrade is only to recover the defaults in old versions,
allowing it to load an old partial yaml config.
Therefore, the implementation only needs to fill in the default values
in the old version when a general downgrade is not possible.
"""
|
参数和注释
cfg (CfgNode)
: 需要降级的配置对象。
to_version (int)
: 目标版本号,表示希望将配置降级到的版本。
注意事项
- 一般情况下,降级任意配置并不总是可行的,因为不同版本之间可能有不同的功能。
- 降级的目的仅仅是恢复旧版本中的默认值,以便加载旧的部分 YAML 配置文件。
- 因此,当一般降级不可行时,实施只需要在旧版本中填充默认值。
函数实现
1
2
3
4
| cfg = cfg.clone()
assert cfg.VERSION >= to_version, "Cannot downgrade from v{} to v{}!".format(
cfg.VERSION, to_version
)
|
- 首先,克隆配置对象
cfg
,以避免直接修改原始配置。
- 断言当前配置版本
cfg.VERSION
大于等于目标版本 to_version
,否则抛出异常,提示无法从高版本降级到低版本。
1
2
3
4
| for k in range(cfg.VERSION, to_version, -1):
converter = globals()["ConverterV" + str(k)]
converter.downgrade(cfg)
cfg.VERSION = k - 1
|
- 使用一个循环从当前版本逐步降级到目标版本。
- 对于每个版本
k
,获取相应的降级转换器 ConverterV{k}
。
- 调用转换器的
downgrade
方法对配置进行降级。
- 将配置的版本号
cfg.VERSION
更新为 k - 1
,表示已经成功降级到下一个版本。
总结
这个函数通过逐步调用相应版本的降级转换器,将配置对象从当前版本降级到指定的旧版本。它确保了在降级过程中填充旧版本的默认值,以便于加载旧的部分配置文件。这种降级操作并不总是可行的,因为不同版本之间的功能可能不同,因此实现只需在一般降级不可行时填充旧版本的默认值。
_RenameConverter
这段代码定义了一个名为 _RenameConverter
的类,用于处理配置项的简单重命名操作。以下是对这段代码的详细解释:
类定义
1
2
3
4
| class _RenameConverter:
"""
A converter that handles simple rename.
"""
|
- 该类名为
_RenameConverter
,主要用于处理配置项的简单重命名。
- 文档字符串说明了该类的用途,即处理简单的重命名操作。
类变量
1
| RENAME: List[Tuple[str, str]] = [] # list of tuples of (old name, new name)
|
- 定义了一个类变量
RENAME
,它是一个包含元组(旧名称,新名称)的列表。
- 这个列表用于存储需要重命名的配置项对。
类方法 upgrade
1
2
3
4
| @classmethod
def upgrade(cls, cfg: CN) -> None:
for old, new in cls.RENAME:
_rename(cfg, old, new)
|
- 这是一个类方法,用于升级配置中的项目名称。
cls
表示类本身,cfg
是配置对象。
- 方法遍历
RENAME
列表中的每个元组 (old, new)
,并调用 _rename(cfg, old, new)
函数将配置项 old
重命名为 new
。
类方法 downgrade
1
2
3
4
| @classmethod
def downgrade(cls, cfg: CN) -> None:
for old, new in cls.RENAME[::-1]:
_rename(cfg, new, old)
|
- 这是一个类方法,用于降级配置中的项目名称。
- 方法遍历
RENAME
列表中的每个元组 (old, new)
,但是顺序是反向的(即从 new
到 old
),并调用 _rename(cfg, new, old)
函数将配置项 new
重命名为 old
。
函数 _rename
虽然代码中没有显示 _rename
函数,但我们可以推测该函数的作用是执行实际的重命名操作。它可能是一个全局函数或模块中的函数,用于在配置对象 cfg
中将名称 old
更改为 new
。
总结
_RenameConverter
类的主要功能是处理配置项的简单重命名操作。
- 通过
RENAME
列表存储重命名对,upgrade
方法将旧名称重命名为新名称,downgrade
方法则将新名称重命名回旧名称。
- 这些方法通过遍历
RENAME
列表并调用 _rename
函数实现重命名操作。
_rename
这段代码定义了一个名为 _rename
的函数,用于在配置对象 cfg
中将某个配置项的名称从 old
重命名为 new
。以下是对这段代码的详细解释:
函数定义
1
| def _rename(cfg: CN, old: str, new: str) -> None:
|
cfg
: 这是一个配置对象,类型为 CN
。
old
: 旧的配置项名称,类型为字符串。
new
: 新的配置项名称,类型为字符串。
- 返回值:无(
None
)。
分割键名
1
2
| old_keys = old.split(".")
new_keys = new.split(".")
|
old_keys
和 new_keys
是通过将 old
和 new
按 .
分割得到的列表。这样可以处理嵌套的配置项名称。
内部函数 _set
1
2
3
4
5
6
7
| def _set(key_seq: List[str], val: str) -> None:
cur = cfg
for k in key_seq[:-1]:
if k not in cur:
cur[k] = CN()
cur = cur[k]
cur[key_seq[-1]] = val
|
_set
函数用于在配置对象中设置一个嵌套的配置项。
key_seq
: 配置项名称的列表(分割后的名称)。
val
: 要设置的值。
cur
初始指向 cfg
,通过遍历 key_seq
,逐层进入嵌套的字典结构。如果中间某一层不存在,就创建一个新的 CN
对象。
- 最后,将值
val
设置到最内层的键上。
内部函数 _get
1
2
3
4
5
| def _get(key_seq: List[str]) -> CN:
cur = cfg
for k in key_seq:
cur = cur[k]
return cur
|
_get
函数用于在配置对象中获取一个嵌套的配置项。
key_seq
: 配置项名称的列表。
cur
初始指向 cfg
,通过遍历 key_seq
,逐层进入嵌套的字典结构,最终返回最内层的值。
内部函数 _del
1
2
3
4
5
6
7
| def _del(key_seq: List[str]) -> None:
cur = cfg
for k in key_seq[:-1]:
cur = cur[k]
del cur[key_seq[-1]]
if len(cur) == 0 and len(key_seq) > 1:
_del(key_seq[:-1])
|
_del
函数用于在配置对象中删除一个嵌套的配置项。
key_seq
: 配置项名称的列表。
cur
初始指向 cfg
,通过遍历 key_seq
,逐层进入嵌套的字典结构,最终删除最内层的键。
- 如果删除后某一层字典为空,并且不是顶层,则递归删除该空层。
重命名操作
1
2
| _set(new_keys, _get(old_keys))
_del(old_keys)
|
- 首先,通过
_get
函数获取旧配置项的值。
- 然后,通过
_set
函数将该值设置到新的配置项位置。
- 最后,通过
_del
函数删除旧的配置项。
总结
这段代码通过三个内部函数 _set
、_get
和 _del
,实现了在配置对象 cfg
中将某个配置项从旧名称重命名为新名称的操作。具体步骤如下:
- 将旧名称对应的值获取出来。
- 将该值设置到新名称对应的位置。
- 删除旧名称对应的配置项。
ConverterV2
这段代码定义了一个名为 ConverterV2
的类,用于处理配置项的重命名和版本升级/降级。它继承自 _RenameConverter
类,并在此基础上扩展了更多的重命名规则和逻辑。以下是对这段代码的详细解释:
类定义
1
2
3
4
| class ConverterV2(_RenameConverter):
"""
A large bulk of rename, before public release.
"""
|
ConverterV2
类继承自 _RenameConverter
。
- 文档字符串说明该类用于在公开发布前进行大量的重命名操作。
重命名列表
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
| RENAME = [
("MODEL.WEIGHT", "MODEL.WEIGHTS"),
("MODEL.PANOPTIC_FPN.SEMANTIC_LOSS_SCALE", "MODEL.SEM_SEG_HEAD.LOSS_WEIGHT"),
("MODEL.PANOPTIC_FPN.RPN_LOSS_SCALE", "MODEL.RPN.LOSS_WEIGHT"),
("MODEL.PANOPTIC_FPN.INSTANCE_LOSS_SCALE", "MODEL.PANOPTIC_FPN.INSTANCE_LOSS_WEIGHT"),
("MODEL.PANOPTIC_FPN.COMBINE_ON", "MODEL.PANOPTIC_FPN.COMBINE.ENABLED"),
(
"MODEL.PANOPTIC_FPN.COMBINE_OVERLAP_THRESHOLD",
"MODEL.PANOPTIC_FPN.COMBINE.OVERLAP_THRESH",
),
(
"MODEL.PANOPTIC_FPN.COMBINE_STUFF_AREA_LIMIT",
"MODEL.PANOPTIC_FPN.COMBINE.STUFF_AREA_LIMIT",
),
(
"MODEL.PANOPTIC_FPN.COMBINE_INSTANCES_CONFIDENCE_THRESHOLD",
"MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH",
),
("MODEL.ROI_HEADS.SCORE_THRESH", "MODEL.ROI_HEADS.SCORE_THRESH_TEST"),
("MODEL.ROI_HEADS.NMS", "MODEL.ROI_HEADS.NMS_THRESH_TEST"),
("MODEL.RETINANET.INFERENCE_SCORE_THRESHOLD", "MODEL.RETINANET.SCORE_THRESH_TEST"),
("MODEL.RETINANET.INFERENCE_TOPK_CANDIDATES", "MODEL.RETINANET.TOPK_CANDIDATES_TEST"),
("MODEL.RETINANET.INFERENCE_NMS_THRESHOLD", "MODEL.RETINANET.NMS_THRESH_TEST"),
("TEST.DETECTIONS_PER_IMG", "TEST.DETECTIONS_PER_IMAGE"),
("TEST.AUG_ON", "TEST.AUG.ENABLED"),
("TEST.AUG_MIN_SIZES", "TEST.AUG.MIN_SIZES"),
("TEST.AUG_MAX_SIZE", "TEST.AUG.MAX_SIZE"),
("TEST.AUG_FLIP", "TEST.AUG.FLIP"),
]
|
RENAME
列表包含了旧配置项名称与新配置项名称的对应关系,用于重命名操作。
类方法 upgrade
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
| @classmethod
def upgrade(cls, cfg: CN) -> None:
super().upgrade(cfg)
if cfg.MODEL.META_ARCHITECTURE == "RetinaNet":
_rename(
cfg, "MODEL.RETINANET.ANCHOR_ASPECT_RATIOS", "MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS"
)
_rename(cfg, "MODEL.RETINANET.ANCHOR_SIZES", "MODEL.ANCHOR_GENERATOR.SIZES")
del cfg["MODEL"]["RPN"]["ANCHOR_SIZES"]
del cfg["MODEL"]["RPN"]["ANCHOR_ASPECT_RATIOS"]
else:
_rename(cfg, "MODEL.RPN.ANCHOR_ASPECT_RATIOS", "MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS")
_rename(cfg, "MODEL.RPN.ANCHOR_SIZES", "MODEL.ANCHOR_GENERATOR.SIZES")
del cfg["MODEL"]["RETINANET"]["ANCHOR_SIZES"]
del cfg["MODEL"]["RETINANET"]["ANCHOR_ASPECT_RATIOS"]
del cfg["MODEL"]["RETINANET"]["ANCHOR_STRIDES"]
|
- 调用
super().upgrade(cfg)
执行基类的升级操作,即执行 RENAME
列表中的重命名。
- 根据配置项
MODEL.META_ARCHITECTURE
的值执行不同的重命名和删除操作:
- 如果
MODEL.META_ARCHITECTURE
是 RetinaNet
,则重命名 MODEL.RETINANET
的一些配置项,并删除 MODEL.RPN
的一些配置项。
- 否则,重命名
MODEL.RPN
的一些配置项,并删除 MODEL.RETINANET
的一些配置项。
- 最后,删除
MODEL.RETINANET
中的 ANCHOR_STRIDES
配置项。
类方法 downgrade
1
2
3
4
5
6
7
8
9
| @classmethod
def downgrade(cls, cfg: CN) -> None:
super().downgrade(cfg)
_rename(cfg, "MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS", "MODEL.RPN.ANCHOR_ASPECT_RATIOS")
_rename(cfg, "MODEL.ANCHOR_GENERATOR.SIZES", "MODEL.RPN.ANCHOR_SIZES")
cfg.MODEL.RETINANET.ANCHOR_ASPECT_RATIOS = cfg.MODEL.RPN.ANCHOR_ASPECT_RATIOS
cfg.MODEL.RETINANET.ANCHOR_SIZES = cfg.MODEL.RPN.ANCHOR_SIZES
cfg.MODEL.RETINANET.ANCHOR_STRIDES = [] # this is not used anywhere in any version
|
- 调用
super().downgrade(cfg)
执行基类的降级操作,即执行 RENAME
列表中的重命名。
- 将
MODEL.ANCHOR_GENERATOR
的一些配置项重命名为 MODEL.RPN
的配置项。
- 将
MODEL.RETINANET
的 ANCHOR_ASPECT_RATIOS
和 ANCHOR_SIZES
配置项设置为 MODEL.RPN
对应的配置项。
- 设置
MODEL.RETINANET.ANCHOR_STRIDES
为一个空列表(在任何版本中都没有使用过)。
总结
ConverterV2
类通过继承 _RenameConverter
并扩展 RENAME
列表,实现了大量配置项的重命名操作。它还根据配置项 MODEL.META_ARCHITECTURE
的值,在升级和降级过程中执行特定的重命名和删除操作。这样可以确保在不同版本之间转换时,配置项名称和结构的一致性。
Yaml Config System
LazyConfig
传统的基于yacs的配置系统提供基本的、标准的功能。然而,它不能为许多新项目提供足够的灵活性。Detectron2 开发了一个可替代的、non-intrusive 的配置系统,可以与detectron2或其他任何复杂的项目一起使用。
Python Syntax
LazyConfig 的配置对象仍然是字典。LazyConfig 没有使用 yaml 来定义字典,而是直接在Python中创建字典。这为用户提供了以下在 yaml 中不存在的功能:
- 使用Python轻松操作字典(添加和删除)。
- 编写简单的算术或调用简单的函数。
- 使用更多数据类型/对象。
- 使用熟悉的Python导入语法导入/编写其他配置文件。
Python配置文件可以这样加载
1
2
3
4
5
6
7
8
| # config.py:
a = dict(x=1, y=2, z=dict(xx=1))
b = dict(x=3, y=4)
# my_code.py:
from detectron2.config import LazyConfig
cfg = LazyConfig.load("path/to/config.py") # an omegaconf dictionary
assert cfg.a.z.xx == 1
|
LazyConfig.load 后,cfg
将是一个字典,其中包含配置文件全局范围内定义的所有字典。请注意:
- 在加载期间,所有字典都转换为 omegaconf 配置对象。这样就可以访问 omegaconf 的特性,比如它的访问语法和插值。
config.py
中的绝对导入与普通Python中的工作原理相同。
- 相对导入只能从配置文件中导入字典。它们只是
LazyConfig.load_rel
的语法糖。它们可以在相对路径上加载Python文件,而不需要 __init__.py
。
LazyConfig.save
可以将配置对象保存为yaml。注意,如果配置文件中出现了不可序列化的对象(例如lambdas),这并不总是成功的。这取决于用户是否要牺牲储蓄的能力来换取灵活性。
Recursive Instantiation
Logger
这段代码定义了一个名为 _ColorfulFormatter
的日志格式化类,它继承自 logging.Formatter
,并且通过在日志消息中添加颜色和样式来增强日志输出的可读性。以下是对代码的详细解释:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
| class _ColorfulFormatter(logging.Formatter):
def __init__(self, *args, **kwargs):
# 从关键字参数中提取 root_name,并在其后加上 "." 作为分隔符
self._root_name = kwargs.pop("root_name") + "."
# 从关键字参数中提取 abbrev_name,如果不存在则默认为空字符串
self._abbrev_name = kwargs.pop("abbrev_name", "")
# 如果 abbrev_name 非空,则在其后加上 "." 作为分隔符
if len(self._abbrev_name):
self._abbrev_name = self._abbrev_name + "."
# 调用父类的构造方法进行初始化
super(_ColorfulFormatter, self).__init__(*args, **kwargs)
def formatMessage(self, record):
# 将日志记录的名称中的 root_name 替换为 abbrev_name
record.name = record.name.replace(self._root_name, self._abbrev_name)
# 使用父类的方法格式化日志记录
log = super(_ColorfulFormatter, self).formatMessage(record)
# 根据日志记录的级别为日志添加不同的前缀
if record.levelno == logging.WARNING:
prefix = colored("WARNING", "red", attrs=["blink"])
elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
prefix = colored("ERROR", "red", attrs=["blink", "underline"])
else:
# 如果日志级别不是 WARNING、ERROR 或 CRITICAL,则直接返回格式化后的日志
return log
# 返回带有前缀和格式化后的日志
return prefix + " " + log
|
主要部分解释:
__init__
方法:
- 这个方法是构造函数,用于初始化
_ColorfulFormatter
类的实例。
self._root_name
:从传入的关键字参数中提取 root_name
,并在其后添加一个点(”.”)作为分隔符。
self._abbrev_name
:从传入的关键字参数中提取 abbrev_name
,如果不存在则默认为空字符串。如果 abbrev_name
非空,则在其后添加一个点(”.”)作为分隔符。
- 调用父类的构造函数对剩余的参数进行处理。
formatMessage
方法:
- 这个方法用于格式化日志记录。
- 首先将日志记录的名称(
record.name
)中的 root_name
替换为 abbrev_name
。
- 使用父类的
formatMessage
方法对日志记录进行格式化。
- 根据日志记录的级别(
record.levelno
)为日志添加不同的前缀:
- 如果是警告级别(
logging.WARNING
),前缀为红色闪烁的 “WARNING”。
- 如果是错误级别(
logging.ERROR
)或严重错误级别(logging.CRITICAL
),前缀为红色闪烁并带有下划线的 “ERROR”。
- 对于其他级别的日志,直接返回格式化后的日志。
- 最终返回带有前缀和格式化后的日志。
这段代码的目的是在日志输出中添加颜色和样式,以便在终端或控制台中更容易识别不同级别的日志消息。
setup_logger
这段代码定义了一个名为 setup_logger
的函数,用于初始化和配置日志记录器。该函数使用了 functools.lru_cache
装饰器,以确保多次调用 setup_logger
时不会添加多个处理程序(handlers)。以下是对代码的详细解释:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
| @functools.lru_cache() # so that calling setup_logger multiple times won't add many handlers
def setup_logger(
output=None,
distributed_rank=0,
*,
color=True,
name="detectron2",
abbrev_name=None,
enable_propagation: bool = False,
configure_stdout: bool = True
):
"""
Initialize the detectron2 logger and set its verbosity level to "DEBUG".
Args:
output (str): a file name or a directory to save log. If None, will not save log file.
If ends with ".txt" or ".log", assumed to be a file name.
Otherwise, logs will be saved to `output/log.txt`.
name (str): the root module name of this logger
abbrev_name (str): an abbreviation of the module, to avoid long names in logs.
Set to "" to not log the root module in logs.
By default, will abbreviate "detectron2" to "d2" and leave other
modules unchanged.
enable_propagation (bool): whether to propagate logs to the parent logger.
configure_stdout (bool): whether to configure logging to stdout.
Returns:
logging.Logger: a logger
"""
# 获取或创建名为 `name` 的日志记录器
logger = logging.getLogger(name)
# 设置日志记录器的级别为 DEBUG
logger.setLevel(logging.DEBUG)
# 设置日志记录器是否将日志传播到父记录器
logger.propagate = enable_propagation
# 如果 abbrev_name 未指定,则根据 name 进行默认设置
if abbrev_name is None:
abbrev_name = "d2" if name == "detectron2" else name
# 创建一个基本的日志格式化器
plain_formatter = logging.Formatter(
"[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S"
)
# 配置 stdout 日志处理程序:仅在主进程中(distributed_rank == 0)进行配置
if configure_stdout and distributed_rank == 0:
ch = logging.StreamHandler(stream=sys.stdout)
ch.setLevel(logging.DEBUG)
# 根据 color 参数决定是否使用带颜色的格式化器
if color:
formatter = _ColorfulFormatter(
colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
datefmt="%m/%d %H:%M:%S",
root_name=name,
abbrev_name=str(abbrev_name),
)
else:
formatter = plain_formatter
ch.setFormatter(formatter)
logger.addHandler(ch)
# 配置文件日志处理程序:适用于所有进程
if output is not None:
# 确定日志文件名
if output.endswith(".txt") or output.endswith(".log"):
filename = output
else:
filename = os.path.join(output, "log.txt")
if distributed_rank > 0:
filename = filename + ".rank{}".format(distributed_rank)
PathManager.mkdirs(os.path.dirname(filename))
# 创建文件日志处理程序并设置格式化器
fh = logging.StreamHandler(_cached_log_stream(filename))
fh.setLevel(logging.DEBUG)
fh.setFormatter(plain_formatter)
logger.addHandler(fh)
return logger
|
主要部分解释:
@functools.lru_cache()
装饰器:
- 这个装饰器用于缓存函数调用的结果,确保多次调用
setup_logger
函数不会重复添加日志处理程序(handlers)。
setup_logger
函数:
- 这个函数用于初始化和配置一个名为
name
的日志记录器,并设置其日志级别为 DEBUG
。
- 参数:
output
:日志文件的名称或保存日志的目录。如果为 None
,则不保存日志文件。
distributed_rank
:分布式训练中进程的排名,仅主进程会配置 stdout
日志处理程序。
color
:是否在日志输出中使用颜色。
name
:日志记录器的根模块名称。
abbrev_name
:模块的缩写名,避免在日志中出现长名称。
enable_propagation
:是否将日志传播到父记录器。
configure_stdout
:是否配置 stdout
日志处理程序。
- 日志处理程序配置:
- stdout 日志处理程序:
- 仅在主进程中配置(
distributed_rank == 0
)。
- 根据
color
参数决定是否使用带颜色的格式化器 _ColorfulFormatter
。
- 文件日志处理程序:
- 适用于所有进程。
- 确定日志文件名并创建相应的目录。
- 创建文件日志处理程序并设置基本格式化器
plain_formatter
。
通过这个函数,可以方便地设置和管理日志记录器的配置,确保在分布式环境中各个进程的日志记录保持一致,并且能够根据需要输出到控制台和日志文件。
functools.lru_cache
是 Python 标准库中的一个装饰器,用于缓存函数的返回值,从而提高函数的性能,特别是对于一些昂贵的或频繁调用的函数。LRU 代表“Least Recently Used”,即“最近最少使用”的意思。这个缓存机制会在缓存空间满时丢弃最久未使用的结果。
基本使用方法
1
2
3
4
5
6
| import functools
@functools.lru_cache(maxsize=128)
def expensive_function(x):
# 一些耗时的计算
return x * x
|
主要特性和参数
maxsize
参数:
maxsize
指定缓存的最大容量。当缓存达到此容量时,最久未使用的条目将被移除。如果 maxsize
设置为 None
,缓存将无限制增长。
- 例如,
@functools.lru_cache(maxsize=128)
会缓存最多 128 个函数结果。
typed
参数:
typed
参数决定是否将不同类型的参数分别缓存。如果设置为 True
,则对于相同的函数参数,但类型不同的情况,会分别缓存。
- 例如,
@functools.lru_cache(typed=True)
对于 func(3)
和 func(3.0)
会缓存两次,虽然它们的值相等,但类型不同。
主要功能
- 缓存返回值:
- 当装饰的函数被调用时,
lru_cache
会检查是否已经缓存了该参数的返回值。如果有缓存值,就直接返回该值,而不需要再次计算。
- 这对于那些需要重复计算且计算量较大的函数非常有用,可以显著提高性能。
- 减少重复计算:
- 通过缓存,可以避免对相同的输入进行多次相同的计算,从而提高效率。
- LRU 机制:
- 当缓存达到最大容量时,LRU 机制会移除最久未使用的缓存条目,以便为新的条目腾出空间。
示例
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
| import functools
import time
@functools.lru_cache(maxsize=3)
def slow_function(n):
time.sleep(n)
return n
print(slow_function(3)) # 首次调用,耗时 3 秒
print(slow_function(3)) # 再次调用,立即返回结果
print(slow_function(2)) # 首次调用,耗时 2 秒
print(slow_function(1)) # 首次调用,耗时 1 秒
print(slow_function(3)) # 缓存中已有,立即返回结果
print(slow_function(4)) # 首次调用,耗时 4 秒
# 此时缓存中应有 (3, 2, 4) 或 (3, 1, 4),因为 maxsize=3
# 调用 slow_function(2) 会导致 slow_function(1) 被移除
|
高级功能
- 缓存信息:
lru_cache
提供了一些方法来获取缓存的信息和管理缓存,比如 cache_info()
和 cache_clear()
。
1
2
| print(slow_function.cache_info()) # 显示缓存的详细信息
slow_function.cache_clear() # 清空缓存
|
functools.lru_cache
是一个非常强大的工具,适用于需要优化性能的场景,特别是在处理重复计算和提高函数调用效率时。
在 setup_logger
函数中使用 functools.lru_cache
装饰器的主要原因是为了避免在多次调用 setup_logger
时重复添加日志处理程序(handlers),从而防止日志重复输出的问题。以下是具体原因和使用 lru_cache
的好处:
1. 避免重复添加日志处理程序
在日志记录系统中,每个日志处理程序(handler)负责将日志记录输出到某个特定的目标(如控制台、文件等)。如果多次调用 setup_logger
而不加以控制,每次调用都会向同一个日志记录器添加新的处理程序,导致相同的日志消息被重复输出多次。这不仅会造成日志混乱,还会影响程序的性能。
2. 提高性能
lru_cache
通过缓存函数的返回值,可以避免多次重复执行相同的初始化过程。对于 setup_logger
函数来说,使用 lru_cache
可以确保只有在必要时才进行日志记录器的配置,从而提高性能。
3. 保持单例模式
在许多应用中,特别是大型项目中,日志记录器通常是全局单例模式的。在这种情况下,确保日志记录器只被初始化一次是很重要的。lru_cache
可以确保同一组参数只初始化一个日志记录器实例,从而实现单例模式。
具体示例
假设 setup_logger
没有使用 lru_cache
,并且在代码的多个地方调用了该函数:
1
2
| logger1 = setup_logger(output="log.txt")
logger2 = setup_logger(output="log.txt")
|
如果没有缓存,每次调用 setup_logger
都会向同一个日志记录器添加处理程序,导致日志消息被重复输出。
使用 lru_cache
后:
1
2
3
4
5
6
7
8
9
10
| import functools
@functools.lru_cache()
def setup_logger(output=None, distributed_rank=0, *, color=True, name="detectron2", abbrev_name=None, enable_propagation: bool = False, configure_stdout: bool = True):
# logger setup code
...
return logger
logger1 = setup_logger(output="log.txt")
logger2 = setup_logger(output="log.txt")
|
由于 lru_cache
缓存了第一次调用的结果,后续对 setup_logger
的调用将直接返回缓存的日志记录器实例,而不会重复添加处理程序。
结论
通过使用 functools.lru_cache
装饰器,setup_logger
可以有效地避免日志处理程序重复添加的问题,确保日志记录器的配置只执行一次,从而保持日志输出的正确性和程序的性能。这对于大型项目或需要频繁初始化日志记录器的场景尤为重要。
_cached_log_stream
这个函数用于缓存日志文件对象,以确保不同调用 setup_logger
时,如果使用相同的文件名,能够安全地写入同一个文件。
1
2
3
4
5
6
7
8
| # cache the opened file object, so that different calls to `setup_logger`
# with the same file name can safely write to the same file.
@functools.lru_cache(maxsize=None)
def _cached_log_stream(filename):
# use 1K buffer if writing to cloud storage
io = PathManager.open(filename, "a", buffering=_get_log_stream_buffer_size(filename))
atexit.register(io.close)
return io
|
@functools.lru_cache(maxsize=None)
:
- 这个装饰器用于缓存函数调用的结果。
maxsize=None
表示缓存大小没有限制,所有不同的文件名都会被缓存。
- 这样做的目的是确保对于相同的文件名,始终返回相同的文件对象,从而避免多次打开同一个文件导致的资源浪费和文件竞争问题。
PathManager.open(filename, "a", buffering=_get_log_stream_buffer_size(filename))
:
PathManager.open
用于打开指定的文件名,以追加模式 ("a"
) 打开文件。
buffering=_get_log_stream_buffer_size(filename)
指定缓冲区大小,具体缓冲区大小由 _get_log_stream_buffer_size
函数决定。
atexit.register(io.close)
:
atexit.register
用于在程序退出时自动关闭打开的文件对象 io
,以确保资源被正确释放。
return io
:
_get_log_stream_buffer_size
这个辅助函数根据文件名确定缓冲区大小。如果文件是本地文件,则不需要额外的缓冲;如果是远程文件,则使用较大的缓冲区以避免频繁的小写操作。
1
2
3
4
5
6
7
8
| def _get_log_stream_buffer_size(filename: str) -> int:
if "://" not in filename:
# Local file, no extra caching is necessary
return -1
# Remote file requires a larger cache to avoid many small writes.
if D2_LOG_BUFFER_SIZE_KEY in os.environ:
return int(os.environ[D2_LOG_BUFFER_SIZE_KEY])
return DEFAULT_LOG_BUFFER_SIZE
|
if "://" not in filename:
:
- 检查文件名是否包含 “://”,以此判断文件是否为本地文件。
- 如果文件是本地文件,则返回
-1
,表示不需要额外的缓冲。
if D2_LOG_BUFFER_SIZE_KEY in os.environ:
:
- 检查环境变量中是否存在
D2_LOG_BUFFER_SIZE_KEY
。如果存在,则使用该环境变量指定的缓冲区大小。
return DEFAULT_LOG_BUFFER_SIZE
:
- 如果上述条件都不满足,则使用默认的缓冲区大小
DEFAULT_LOG_BUFFER_SIZE
。
_find_caller
这段代码定义了一个名为 _find_caller
的函数,用于查找调用者的模块名称和唯一标识符。以下是对代码的详细解释:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
| def _find_caller():
"""
Returns:
str: module name of the caller
tuple: a hashable key to be used to identify different callers
"""
frame = sys._getframe(2)
while frame:
code = frame.f_code
if os.path.join("utils", "logger.") not in code.co_filename:
mod_name = frame.f_globals["__name__"]
if mod_name == "__main__":
mod_name = "detectron2"
return mod_name, (code.co_filename, frame.f_lineno, code.co_name)
frame = frame.f_back
|
主要部分解释:
- 函数签名和文档字符串:
- 函数名为
_find_caller
,没有参数。
- 文档字符串描述了返回值:
str
:调用者的模块名称。
tuple
:用于标识不同调用者的可哈希键。
frame = sys._getframe(2)
:
sys._getframe
是 Python 内置模块 sys
中的一个函数,返回当前或上层调用堆栈中的帧对象。
sys._getframe(2)
获取调用该函数的函数的上上层帧对象,即跳过两层调用。
while frame:
:
- 这是一个循环,用于遍历调用堆栈中的每一帧,直到找到符合条件的调用者或堆栈结束。
code = frame.f_code
:
frame.f_code
获取帧对象对应的代码对象。
if os.path.join("utils", "logger.") not in code.co_filename:
:
- 检查当前帧的代码文件名(
code.co_filename
)中是否包含指定的路径片段 "utils/logger."
。
- 如果不包含,表示找到了调用者的帧。
mod_name = frame.f_globals["__name__"]
:
- 获取调用者模块的名称(即全局变量
__name__
的值)。
if mod_name == "__main__":
:
- 检查模块名称是否为
__main__
,如果是,则将其修改为 "detectron2"
。这是因为在直接运行脚本时,模块名称会被设置为 __main__
。
return mod_name, (code.co_filename, frame.f_lineno, code.co_name)
:
- 返回调用者的模块名称和一个元组。这个元组包含调用者的文件名、行号和函数名,用于唯一标识不同的调用者。
frame = frame.f_back
:
目的和用途
这个函数的主要用途是在日志记录或调试工具中,用于查找并标识调用特定函数的模块和位置。通过返回调用者的模块名称和一个唯一标识符,可以帮助在复杂的代码库中跟踪日志记录或错误的来源。
示例
假设在某个模块中调用了这个函数:
1
2
3
4
5
6
7
8
| import sys
import os
def test_function():
caller_info = _find_caller()
print(caller_info)
test_function()
|
输出可能类似于:
1
| ('example_module', ('/path/to/example_module.py', 10, 'test_function'))
|
这表示调用 test_function
的模块是 example_module
,且调用发生在文件 /path/to/example_module.py
的第 10 行,函数名为 test_function
。
通过这种方式,可以准确定位调用者的具体位置,便于日志记录和问题排查。
log_first_n
这段代码定义了一个名为 log_first_n
的函数,用于在前 n
次调用时记录日志。以下是对代码的详细解释:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
| def log_first_n(lvl, msg, n=1, *, name=None, key="caller"):
"""
Log only for the first n times.
Args:
lvl (int): the logging level
msg (str):
n (int):
name (str): name of the logger to use. Will use the caller's module by default.
key (str or tuple[str]): the string(s) can be one of "caller" or
"message", which defines how to identify duplicated logs.
For example, if called with `n=1, key="caller"`, this function
will only log the first call from the same caller, regardless of
the message content.
If called with `n=1, key="message"`, this function will log the
same content only once, even if they are called from different places.
If called with `n=1, key=("caller", "message")`, this function
will not log only if the same caller has logged the same message before.
"""
if isinstance(key, str):
key = (key,)
assert len(key) > 0
caller_module, caller_key = _find_caller()
hash_key = ()
if "caller" in key:
hash_key = hash_key + caller_key
if "message" in key:
hash_key = hash_key + (msg,)
_LOG_COUNTER[hash_key] += 1
if _LOG_COUNTER[hash_key] <= n:
logging.getLogger(name or caller_module).log(lvl, msg)
|
主要部分解释:
- 函数签名和文档字符串:
log_first_n
函数有四个参数:
lvl
:日志级别(如 logging.INFO
)。
msg
:日志消息。
n
:指定记录日志的最大次数,默认为 1。
name
:使用的日志记录器名称,如果未指定,将使用调用者的模块名。
key
:用于识别重复日志的方法,可以是 “caller”、”message” 或二者的组合。
- 处理
key
参数:
- 如果
key
是字符串,则将其转换为元组,以便后续处理。
- 确保
key
至少包含一个元素。
- 查找调用者信息:
- 调用
_find_caller()
获取调用者的模块名称和唯一标识符。
caller_module
:调用者的模块名称。
caller_key
:一个包含文件名、行号和函数名的元组,用于唯一标识调用者。
- 生成
hash_key
:
- 根据
key
参数确定如何生成 hash_key
,用于标识日志消息的唯一性:
- 如果
key
包含 “caller”,则将 caller_key
添加到 hash_key
中。
- 如果
key
包含 “message”,则将 msg
添加到 hash_key
中。
- 这样可以确保根据调用者或消息内容来识别是否是重复日志。
- 记录日志次数:
- 使用
_LOG_COUNTER
记录每个 hash_key
的日志次数。
- 如果某个
hash_key
的日志次数未超过 n
,则记录日志。
- 日志记录:
- 使用
logging.getLogger
获取或创建指定名称的日志记录器。
- 记录日志消息,使用传入的日志级别
lvl
。
_LOG_COUNTER
变量
在代码中提到的 _LOG_COUNTER
应该是一个全局的计数器对象,用于记录每个 hash_key
的日志次数。通常可以使用 Counter()
来实现:
1
| _LOG_COUNTER = Counter()
|
目的和用途
log_first_n
函数的主要用途是在复杂系统中控制日志输出,避免过多的重复日志消息。这对于调试和监控系统状态非常有用,可以减少日志的冗余信息,使日志更有意义。
示例
假设我们在代码中调用 log_first_n
函数:
1
| log_first_n(logging.INFO, "This is a test message", n=3, key="message")
|
- 第一次调用:日志消息 “This is a test message” 会被记录。
- 第二次调用:日志消息仍会被记录。
- 第三次调用:日志消息仍会被记录。
- 第四次及之后的调用:日志消息不会被记录,因为已达到最大次数
n=3
。
这种控制可以有效地防止日志文件中充斥大量重复的日志消息,从而使日志文件更加简洁和易于阅读。
log_every_n
这段代码定义了一个名为 log_every_n
的函数,用于每隔 n
次调用记录一次日志。以下是对代码的详细解释:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
| def log_every_n(lvl, msg, n=1, *, name=None):
"""
Log once per n times.
Args:
lvl (int): the logging level
msg (str):
n (int):
name (str): name of the logger to use. Will use the caller's module by default.
"""
caller_module, key = _find_caller()
_LOG_COUNTER[key] += 1
if n == 1 or _LOG_COUNTER[key] % n == 1:
logging.getLogger(name or caller_module).log(lvl, msg)
|
主要部分解释:
- 函数签名和文档字符串:
log_every_n
函数有四个参数:
lvl
:日志级别(如 logging.INFO
)。
msg
:日志消息。
n
:指定每隔多少次记录一次日志,默认为 1。
name
:使用的日志记录器名称,如果未指定,将使用调用者的模块名。
- 查找调用者信息:
- 调用
_find_caller()
获取调用者的模块名称和唯一标识符:
caller_module
:调用者的模块名称。
key
:一个包含文件名、行号和函数名的元组,用于唯一标识调用者。
- 更新和检查计数器:
_LOG_COUNTER[key] += 1
:使用 _LOG_COUNTER
记录每个 key
的调用次数,并将其加 1。
if n == 1 or _LOG_COUNTER[key] % n == 1
:检查是否需要记录日志:
- 如果
n
等于 1,每次都会记录日志。
- 如果
n
大于 1,当调用次数对 n
取模等于 1 时记录日志,即每隔 n
次记录一次日志。
- 日志记录:
- 使用
logging.getLogger
获取或创建指定名称的日志记录器。
- 记录日志消息,使用传入的日志级别
lvl
。
目的和用途
log_every_n
函数的主要用途是在复杂系统中控制日志输出,使得日志只在每隔一定次数调用时记录一次。这对于避免日志的频繁输出、减少日志文件的大小以及提升日志的可读性非常有用。
示例
假设我们在代码中调用 log_every_n
函数:
1
| log_every_n(logging.INFO, "This is a test message", n=3)
|
- 第一次调用:日志消息 “This is a test message” 会被记录。
- 第二次调用:日志消息不会被记录。
- 第三次调用:日志消息不会被记录。
- 第四次调用:日志消息 “This is a test message” 会被记录。
- 第五次调用:日志消息不会被记录。
- 第六次调用:日志消息不会被记录。
- 第七次调用:日志消息 “This is a test message” 会被记录。
通过这种方式,可以有效地控制日志的频率,使得日志信息更加有序和易于管理。
log_every_n_second
这段代码定义了一个名为 log_every_n_seconds
的函数,用于确保日志记录的频率不超过每 n
秒一次。以下是对代码的详细解释:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
| def log_every_n_seconds(lvl, msg, n=1, *, name=None):
"""
Log no more than once per n seconds.
Args:
lvl (int): the logging level
msg (str):
n (int): minimum interval in seconds between log messages
name (str): name of the logger to use. Will use the caller's module by default.
"""
caller_module, key = _find_caller()
last_logged = _LOG_TIMER.get(key, None)
current_time = time.time()
if last_logged is None or current_time - last_logged >= n:
logging.getLogger(name or caller_module).log(lvl, msg)
_LOG_TIMER[key] = current_time
|
主要部分解释:
- 函数签名和文档字符串:
log_every_n_seconds
函数有四个参数:
lvl
:日志级别(如 logging.INFO
)。
msg
:日志消息。
n
:指定日志记录之间的最小间隔时间,单位为秒,默认为 1 秒。
name
:使用的日志记录器名称,如果未指定,将使用调用者的模块名。
- 查找调用者信息:
- 调用
_find_caller()
获取调用者的模块名称和唯一标识符:
caller_module
:调用者的模块名称。
key
:一个包含文件名、行号和函数名的元组,用于唯一标识调用者。
- 检查和更新日志记录时间:
last_logged = _LOG_TIMER.get(key, None)
:从 _LOG_TIMER
中获取上次记录日志的时间。如果没有记录则返回 None
。
current_time = time.time()
:获取当前时间(以秒为单位)。
- 条件检查和日志记录:
if last_logged is None or current_time - last_logged >= n
:检查是否需要记录日志:
- 如果
last_logged
为 None
,表示这是第一次记录日志。
- 如果当前时间与上次记录日志的时间差大于等于
n
秒,则记录日志。
logging.getLogger(name or caller_module).log(lvl, msg)
:使用指定的日志记录器记录日志消息,日志级别为 lvl
。
_LOG_TIMER[key] = current_time
:更新 _LOG_TIMER
中该调用者的日志记录时间。
目的和用途
log_every_n_seconds
函数的主要用途是在复杂系统中控制日志输出频率,确保在指定的时间间隔内只记录一次日志。这对于减少日志文件的大小和避免频繁的日志写入非常有用。
示例
假设我们在代码中调用 log_every_n_seconds
函数:
1
| log_every_n_seconds(logging.INFO, "This is a test message", n=10)
|
- 第一次调用:日志消息 “This is a test message” 会被记录。
- 接下来的 10 秒内的调用:日志消息不会被记录。
- 10 秒之后的调用:日志消息 “This is a test message” 会再次被记录。
通过这种方式,可以有效地控制日志的记录频率,使得日志信息更加有序和易于管理,同时减少了日志的冗余。
create_small_table
这段代码定义了一个名为 create_small_table
的函数,用于使用字典的键作为表头创建一个小表格。以下是对代码的详细解释:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
| def create_small_table(small_dict):
"""
Create a small table using the keys of small_dict as headers. This is only
suitable for small dictionaries.
Args:
small_dict (dict): a result dictionary of only a few items.
Returns:
str: the table as a string.
"""
keys, values = tuple(zip(*small_dict.items()))
table = tabulate(
[values],
headers=keys,
tablefmt="pipe",
floatfmt=".3f",
stralign="center",
numalign="center",
)
return table
|
主要部分解释:
- 函数签名和文档字符串:
create_small_table
函数有一个参数:
- 文档字符串描述了函数的用途、参数和返回值:
- 该函数使用字典的键作为表头创建一个小表格。
- 参数
small_dict
是一个只有少量项目的结果字典。
- 返回值是一个字符串,表示表格。
- 提取键和值:
keys, values = tuple(zip(*small_dict.items()))
:
small_dict.items()
返回一个包含字典项(键值对)的视图。
zip(*small_dict.items())
将键值对分解为两个独立的元组,一个包含所有键,另一个包含所有值。
keys
包含字典的所有键,values
包含字典的所有值。
- 生成表格:
tabulate
函数用于生成格式化的表格字符串:
[values]
将值封装在一个列表中,因为 tabulate
期望的输入是一个二维数组。
headers=keys
指定表头为字典的键。
tablefmt="pipe"
指定表格的格式为管道样式(适用于 Markdown)。
floatfmt=".3f"
指定浮点数的格式为小数点后三位。
stralign="center"
将字符串居中对齐。
numalign="center"
将数字居中对齐。
- 返回表格:
示例
假设我们有一个小字典并调用 create_small_table
函数:
1
2
3
4
5
6
7
8
9
10
11
| from tabulate import tabulate
small_dict = {
"Name": "Alice",
"Age": 30,
"Height": 165.5,
"Weight": 68.2
}
table_str = create_small_table(small_dict)
print(table_str)
|
输出可能类似于:
1
2
3
| | Name | Age | Height | Weight |
|:------:|:-----:|:--------:|:------:|
| Alice | 30.000| 165.500 | 68.200 |
|
tabulate
函数
tabulate
是一个第三方库中的函数,用于生成格式化的表格。常用参数包括:
headers
:指定表头。
tablefmt
:指定表格格式(如 “pipe”、”grid”、”html” 等)。
floatfmt
:指定浮点数格式。
stralign
:指定字符串对齐方式。
numalign
:指定数字对齐方式。
通过使用 tabulate
函数,可以方便地生成各种格式的表格,适用于不同的展示需求。
目的和用途
create_small_table
函数的主要用途是在需要以表格形式展示少量数据时提供一个简洁的方法。特别是在报告生成、数据分析或日志记录中,当需要展示字典内容时,该函数可以快速生成易读的表格格式。
_log_api_usage
这段代码定义了一个名为 _log_api_usage
的函数,用于记录不同 detectron2
组件的使用情况。该函数主要用于在 Facebook 的内部基础设施中进行日志记录。以下是对代码的详细解释:
1
2
3
4
5
6
| def _log_api_usage(identifier: str):
"""
Internal function used to log the usage of different detectron2 components
inside facebook's infra.
"""
torch._C._log_api_usage_once("detectron2." + identifier)
|
主要部分解释:
- 函数签名和文档字符串:
def _log_api_usage(identifier: str)
:
- 函数名为
_log_api_usage
,参数 identifier
是一个字符串,用于标识 detectron2
组件。
- 文档字符串描述了函数的用途:
- 这是一个内部函数,用于记录在 Facebook 的基础设施中不同
detectron2
组件的使用情况。
- 日志记录:
torch._C._log_api_usage_once("detectron2." + identifier)
:
- 这个语句调用了 PyTorch 内部的一个 C++ 扩展函数
_log_api_usage_once
。
torch._C
是 PyTorch 的内部 C++ 模块,_log_api_usage_once
是其中一个用于记录 API 使用情况的函数。
"detectron2." + identifier
生成一个字符串,用于标识具体的 detectron2
组件。这个字符串由前缀 "detectron2."
和传入的 identifier
组成。
目的和用途
这个函数的主要用途是在内部基础设施中记录 detectron2
组件的使用情况。这对于监控和分析不同组件的使用频率和模式非常有用,特别是在大型项目中,可以帮助开发团队了解哪些功能被频繁使用,哪些功能需要改进或优化。
示例
假设我们有以下调用:
1
| _log_api_usage("model_loading")
|
这会记录一个标识符 "detectron2.model_loading"
,表示 detectron2
中的模型加载功能被使用了一次。这种日志记录有助于在后台系统中进行数据分析和使用情况跟踪。
总结
- 功能:记录
detectron2
组件的使用情况。
- 参数:
identifier
,用于标识具体的 detectron2
组件。
- 内部调用:使用了 PyTorch 的内部函数
_log_api_usage_once
进行日志记录。
通过这种方式,可以方便地跟踪和分析 detectron2
组件的使用情况,为优化和改进提供数据支持。
default_setup
这段代码定义了一个函数 default_setup
,用于在作业开始时进行一些基本的常见设置。它涉及日志记录、环境信息的收集、配置文件的备份以及一些与分布式训练相关的设置。下面是对代码的逐行解释:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
| def default_setup(cfg, args):
"""
Perform some basic common setups at the beginning of a job, including:
1. Set up the detectron2 logger
2. Log basic information about environment, cmdline arguments, and config
3. Backup the config to the output directory
Args:
cfg (CfgNode or omegaconf.DictConfig): the full config to be used
args (argparse.NameSpace): the command line arguments to be logged
"""
output_dir = _try_get_key(cfg, "OUTPUT_DIR", "output_dir", "train.output_dir")
if comm.is_main_process() and output_dir:
PathManager.mkdirs(output_dir)
rank = comm.get_rank()
setup_logger(output_dir, distributed_rank=rank, name="fvcore")
logger = setup_logger(output_dir, distributed_rank=rank)
logger.info("Rank of current process: {}. World size: {}".format(rank, comm.get_world_size()))
logger.info("Environment info:\n" + collect_env_info())
logger.info("Command line arguments: " + str(args))
if hasattr(args, "config_file") and args.config_file != "":
logger.info(
"Contents of args.config_file={}:\n{}".format(
args.config_file,
_highlight(PathManager.open(args.config_file, "r").read(), args.config_file),
)
)
if comm.is_main_process() and output_dir:
# Note: some of our scripts may expect the existence of
# config.yaml in output directory
path = os.path.join(output_dir, "config.yaml")
if isinstance(cfg, CfgNode):
logger.info("Running with full config:\n{}".format(_highlight(cfg.dump(), ".yaml")))
with PathManager.open(path, "w") as f:
f.write(cfg.dump())
else:
LazyConfig.save(cfg, path)
logger.info("Full config saved to {}".format(path))
# make sure each worker has a different, yet deterministic seed if specified
seed = _try_get_key(cfg, "SEED", "train.seed", default=-1)
seed_all_rng(None if seed < 0 else seed + rank)
# cudnn benchmark has large overhead. It shouldn't be used considering the small size of
# typical validation set.
if not (hasattr(args, "eval_only") and args.eval_only):
torch.backends.cudnn.benchmark = _try_get_key(
cfg, "CUDNN_BENCHMARK", "train.cudnn_benchmark", default=False
)
|
详细解释
- 函数说明
cfg
:配置对象,可以是 CfgNode
或 omegaconf.DictConfig
。
args
:命令行参数对象。
- 获取输出目录并创建目录
1
2
3
| output_dir = _try_get_key(cfg, "OUTPUT_DIR", "output_dir", "train.output_dir")
if comm.is_main_process() and output_dir:
PathManager.mkdirs(output_dir)
|
- 从配置中获取输出目录。
- 如果是主进程且输出目录存在,则创建输出目录。
- 设置日志记录器
1
2
3
| rank = comm.get_rank()
setup_logger(output_dir, distributed_rank=rank, name="fvcore")
logger = setup_logger(output_dir, distributed_rank=rank)
|
- 记录基本信息
1
2
3
4
5
6
7
8
9
10
| logger.info("Rank of current process: {}. World size: {}".format(rank, comm.get_world_size()))
logger.info("Environment info:\n" + collect_env_info())
logger.info("Command line arguments: " + str(args))
if hasattr(args, "config_file") and args.config_file != "":
logger.info(
"Contents of args.config_file={}:\n{}".format(
args.config_file,
_highlight(PathManager.open(args.config_file, "r").read(), args.config_file),
)
)
|
- 记录当前进程的 rank 和 world size。
- 记录环境信息。
- 记录命令行参数。
- 如果命令行参数中包含配置文件,则记录配置文件内容。
- 备份配置
1
2
3
4
5
6
7
8
9
| if comm.is_main_process() and output_dir:
path = os.path.join(output_dir, "config.yaml")
if isinstance(cfg, CfgNode):
logger.info("Running with full config:\n{}".format(_highlight(cfg.dump(), ".yaml")))
with PathManager.open(path, "w") as f:
f.write(cfg.dump())
else:
LazyConfig.save(cfg, path)
logger.info("Full config saved to {}".format(path))
|
- 如果是主进程且输出目录存在,将配置保存到输出目录中的
config.yaml
文件中。
- 设置随机种子
1
2
| seed = _try_get_key(cfg, "SEED", "train.seed", default=-1)
seed_all_rng(None if seed < 0 else seed + rank)
|
- 从配置中获取随机种子。
- 确保每个 worker 有不同的但确定的随机种子。
- 设置 cudnn benchmark
1
2
3
4
| if not (hasattr(args, "eval_only") and args.eval_only):
torch.backends.cudnn.benchmark = _try_get_key(
cfg, "CUDNN_BENCHMARK", "train.cudnn_benchmark", default=False
)
|
- 如果不是只进行评估,设置 cudnn benchmark 的值。
在 default_setup
函数中,两次调用 setup_logger
是为了分别设置不同的日志记录器,以满足不同的日志记录需求。
1
2
| setup_logger(output_dir, distributed_rank=rank, name="fvcore")
logger = setup_logger(output_dir, distributed_rank=rank)
|
详细解释
- 第一次调用
setup_logger
1
| setup_logger(output_dir, distributed_rank=rank, name="fvcore")
|
- 这里设置了名为 “fvcore” 的日志记录器,用于记录与 fvcore 库相关的信息。
- 通过
name="fvcore"
参数,确保 fvcore 日志记录器的配置和名称独立,以便单独处理或过滤 fvcore 日志。
- 第二次调用
setup_logger
1
| logger = setup_logger(output_dir, distributed_rank=rank)
|
- 这里没有指定
name
参数,因此这是设置一个默认的日志记录器,用于记录通用日志信息。
- 这使得代码可以记录一般的日志信息,而不受特定库日志记录器(如 fvcore)的限制。
为什么需要两次调用
- 模块化日志记录:使用不同的日志记录器可以将日志分开,方便管理和调试。例如,一个专门用于记录与 fvcore 相关的信息,另一个用于记录通用信息。
- 独立配置:不同日志记录器可以有不同的配置,比如日志级别、格式等。
- 代码清晰度:通过明确区分日志记录器,可以更清楚地知道某些日志信息来自哪个部分,有助于在复杂项目中定位问题。
总结:
- 第一次调用 用于设置一个专门记录
fvcore
库信息的日志记录器。
- 第二次调用 用于设置一个默认的日志记录器,用于记录通用信息。
这种设计提高了日志记录的灵活性和可管理性,有助于在大规模和复杂项目中有效地进行日志记录和调试。
Dataset
MetaDataCatalog
DatasetCatalog
这段代码定义了一个 _DatasetCatalog
类,它继承自 UserDict
,用来存储和管理数据集的信息以及如何获取这些数据集。这个类提供了注册、获取、列出和移除数据集的功能。以下是对代码的详细解释:
类定义和文档注释
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
| class _DatasetCatalog(UserDict):
"""
A global dictionary that stores information about the datasets and how to obtain them.
It contains a mapping from strings
(which are names that identify a dataset, e.g. "coco_2014_train")
to a function which parses the dataset and returns the samples in the
format of `list[dict]`.
The returned dicts should be in Detectron2 Dataset format (See DATASETS.md for details)
if used with the data loader functionalities in `data/build.py,data/detection_transform.py`.
The purpose of having this catalog is to make it easy to choose
different datasets, by just using the strings in the config.
"""
|
这段文档注释解释了 _DatasetCatalog
的用途和功能:
- 它是一个全局字典,存储了数据集的信息和如何获取这些数据集。
- 数据集名称(如 “coco_2014_train”)映射到一个函数,该函数解析数据集并返回
list[dict]
格式的样本。
- 返回的字典应该符合 Detectron2 数据集格式。
- 这个目录的目的是通过使用配置中的字符串轻松选择不同的数据集。
注册数据集的函数
1
2
3
4
5
6
7
8
9
10
| def register(self, name, func):
"""
Args:
name (str): the name that identifies a dataset, e.g. "coco_2014_train".
func (callable): a callable which takes no arguments and returns a list of dicts.
It must return the same results if called multiple times.
"""
assert callable(func), "You must register a function with `DatasetCatalog.register`!"
assert name not in self, "Dataset '{}' is already registered!".format(name)
self[name] = func
|
register
方法用于注册数据集。
name
是数据集的名称。
func
是一个可调用对象(函数),它不接受参数并返回一个字典列表。
- 通过断言确保
func
是可调用的,并且 name
没有被重复注册。
- 将
name
和 func
存储在字典中。
获取数据集的函数
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
| def get(self, name):
"""
Call the registered function and return its results.
Args:
name (str): the name that identifies a dataset, e.g. "coco_2014_train".
Returns:
list[dict]: dataset annotations.
"""
try:
f = self[name]
except KeyError as e:
raise KeyError(
"Dataset '{}' is not registered! Available datasets are: {}".format(
name, ", ".join(list(self.keys()))
)
) from e
return f()
|
get
方法用于获取数据集。
name
是数据集的名称。
- 尝试从字典中获取注册的函数并调用它。如果名称不存在,抛出
KeyError
并列出所有可用的数据集。
列出所有注册数据集的函数
1
2
3
4
5
6
7
8
| def list(self) -> List[str]:
"""
List all registered datasets.
Returns:
list[str]
"""
return list(self.keys())
|
移除数据集的函数
1
2
3
4
5
| def remove(self, name):
"""
Alias of ``pop``.
"""
self.pop(name)
|
remove
方法是 pop
方法的别名,用于移除注册的数据集。
字符串表示和重定义 __repr__
1
2
3
4
| def __str__(self):
return "DatasetCatalog(registered datasets: {})".format(", ".join(self.keys()))
__repr__ = __str__
|
__str__
方法返回包含所有注册数据集名称的字符串表示。
__repr__
被重定义为 __str__
,这样打印对象时会显示相同的信息。
总结
这个 _DatasetCatalog
类提供了一个全局字典,用于存储和管理数据集信息。通过注册函数,用户可以方便地添加、获取、列出和移除数据集。这个设计使得在配置中使用数据集名称来选择不同的数据集变得非常简单。
register_coco_instances
这段代码定义了一个 register_coco_instances
函数,用于注册一个以 COCO 格式存储的实例检测、实例分割和关键点检测的数据集。这个函数展示了如何注册新的数据集,以便在数据加载、评估、可视化或日志记录中使用。
代码详解
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
| def register_coco_instances(name, metadata, json_file, image_root):
"""
Register a dataset in COCO's json annotation format for
instance detection, instance segmentation and keypoint detection.
(i.e., Type 1 and 2 in http://cocodataset.org/#format-data.
`instances*.json` and `person_keypoints*.json` in the dataset).
This is an example of how to register a new dataset.
You can do something similar to this function, to register new datasets.
Args:
name (str): the name that identifies a dataset, e.g. "coco_2014_train".
metadata (dict): extra metadata associated with this dataset. You can
leave it as an empty dict.
json_file (str): path to the json instance annotation file.
image_root (str or path-like): directory which contains all the images.
"""
assert isinstance(name, str), name
assert isinstance(json_file, (str, os.PathLike)), json_file
assert isinstance(image_root, (str, os.PathLike)), image_root
# 1. register a function which returns dicts
DatasetCatalog.register(name, lambda: load_coco_json(json_file, image_root, name))
# 2. Optionally, add metadata about this dataset,
# since they might be useful in evaluation, visualization or logging
MetadataCatalog.get(name).set(
json_file=json_file, image_root=image_root, evaluator_type="coco", **metadata
)
|
函数定义与参数
name
:字符串,标识数据集的名称,例如 “coco_2014_train”。
metadata
:字典,包含与该数据集关联的额外元数据,可以是空字典。
json_file
:字符串,指向 COCO 格式的 JSON 注释文件的路径。
image_root
:字符串或路径,包含所有图像的目录。
断言
1
2
3
| assert isinstance(name, str), name
assert isinstance(json_file, (str, os.PathLike)), json_file
assert isinstance(image_root, (str, os.PathLike)), image_root
|
- 确保
name
是字符串。
- 确保
json_file
是字符串或路径对象。
- 确保
image_root
是字符串或路径对象。
注册数据集函数
1
| DatasetCatalog.register(name, lambda: load_coco_json(json_file, image_root, name))
|
- 使用
DatasetCatalog.register
方法注册一个数据集。
name
是数据集的名称。
- 注册一个匿名函数(
lambda
),该函数调用 load_coco_json
函数并返回数据集字典。
添加元数据
1
2
3
| MetadataCatalog.get(name).set(
json_file=json_file, image_root=image_root, evaluator_type="coco", **metadata
)
|
- 使用
MetadataCatalog.get(name).set
方法为数据集添加元数据。
json_file
和 image_root
是数据集的路径信息。
evaluator_type="coco"
指定评估器类型为 COCO。
- 使用
**metadata
添加额外的元数据。
总结
这个函数展示了如何注册一个新的 COCO 格式的数据集。通过调用 DatasetCatalog.register
,可以注册一个返回数据集字典的函数,并使用 MetadataCatalog.get(name).set
添加与数据集相关的元数据。这些注册和元数据在数据加载、评估、可视化或日志记录中非常有用。
load_coco_json
这段代码定义了一个 load_coco_json
函数,用于加载以 COCO 格式存储的实例检测、实例分割和关键点检测的 JSON 注释文件,并将其转换为 Detectron2 标准数据集字典格式。
代码详解
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
| def load_coco_json(json_file, image_root, dataset_name=None, extra_annotation_keys=None):
"""
Load a json file with COCO's instances annotation format.
Currently supports instance detection, instance segmentation,
and person keypoints annotations.
Args:
json_file (str): full path to the json file in COCO instances annotation format.
image_root (str or path-like): the directory where the images in this json file exists.
dataset_name (str or None): the name of the dataset (e.g., coco_2017_train).
When provided, this function will also do the following:
* Put "thing_classes" into the metadata associated with this dataset.
* Map the category ids into a contiguous range (needed by standard dataset format),
and add "thing_dataset_id_to_contiguous_id" to the metadata associated
with this dataset.
This option should usually be provided, unless users need to load
the original json content and apply more processing manually.
extra_annotation_keys (list[str]): list of per-annotation keys that should also be
loaded into the dataset dict (besides "iscrowd", "bbox", "keypoints",
"category_id", "segmentation"). The values for these keys will be returned as-is.
For example, the densepose annotations are loaded in this way.
Returns:
list[dict]: a list of dicts in Detectron2 standard dataset dicts format (See
`Using Custom Datasets </tutorials/datasets.html>`_ ) when `dataset_name` is not None.
If `dataset_name` is None, the returned `category_ids` may be
incontiguous and may not conform to the Detectron2 standard format.
Notes:
1. This function does not read the image files.
The results do not have the "image" field.
"""
|
参数说明
json_file
:COCO 实例注释格式的 JSON 文件的完整路径。
image_root
:包含 JSON 文件中的图像的目录。
dataset_name
:数据集的名称,例如 “coco_2017_train”。提供该参数时,函数会将类别信息添加到数据集的元数据中,并将类别 ID 映射到连续的范围。
extra_annotation_keys
:要加载到数据集字典中的额外注释键列表(除了默认的 “iscrowd”、”bbox”、”keypoints”、”category_id”、”segmentation” 之外)。
导入和初始化
1
2
3
4
5
6
7
8
| from pycocotools.coco import COCO
timer = Timer()
json_file = PathManager.get_local_path(json_file)
with contextlib.redirect_stdout(io.StringIO()):
coco_api = COCO(json_file)
if timer.seconds() > 1:
logger.info("Loading {} takes {:.2f} seconds.".format(json_file, timer.seconds()))
|
- 导入
pycocotools.coco
模块。
- 使用
PathManager
获取 JSON 文件的本地路径。
- 使用
COCO
类加载 JSON 文件。
- 如果加载时间超过 1 秒,记录加载时间。
处理类别和 ID 映射
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
| id_map = None
if dataset_name is not None:
meta = MetadataCatalog.get(dataset_name)
cat_ids = sorted(coco_api.getCatIds())
cats = coco_api.loadCats(cat_ids)
thing_classes = [c["name"] for c in sorted(cats, key=lambda x: x["id"])]
meta.thing_classes = thing_classes
if not (min(cat_ids) == 1 and max(cat_ids) == len(cat_ids)):
if "coco" not in dataset_name:
logger.warning(
"""
Category ids in annotations are not in [1, #categories]! We'll apply a mapping for you.
"""
)
id_map = {v: i for i, v in enumerate(cat_ids)}
meta.thing_dataset_id_to_contiguous_id = id_map
|
- 获取类别 ID 和名称,并将其存储在
MetadataCatalog
中。
- 如果类别 ID 不连续且不在 [1, #categories] 范围内,将其映射到连续的范围并记录警告信息。
加载图像和注释
1
2
3
| img_ids = sorted(coco_api.imgs.keys())
imgs = coco_api.loadImgs(img_ids)
anns = [coco_api.imgToAnns[img_id] for img_id in img_ids]
|
- 获取图像 ID 并加载对应的图像信息。
- 获取每个图像的注释。
验证和处理注释
1
2
3
4
5
6
7
8
9
10
11
12
13
| total_num_valid_anns = sum([len(x) for x in anns])
total_num_anns = len(coco_api.anns)
if total_num_valid_anns < total_num_anns:
logger.warning(
f"{json_file} contains {total_num_anns} annotations, but only "
f"{total_num_valid_anns} of them match to images in the file."
)
if "minival" not in json_file:
ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image]
assert len(set(ann_ids)) == len(ann_ids), "Annotation ids in '{}' are not unique!".format(
json_file
)
|
- 检查注释数量是否匹配图像数量,并记录警告信息。
- 验证注释 ID 是否唯一。
构建数据集字典
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
| imgs_anns = list(zip(imgs, anns))
logger.info("Loaded {} images in COCO format from {}".format(len(imgs_anns), json_file))
dataset_dicts = []
ann_keys = ["iscrowd", "bbox", "keypoints", "category_id"] + (extra_annotation_keys or [])
num_instances_without_valid_segmentation = 0
for img_dict, anno_dict_list in imgs_anns:
record = {}
record["file_name"] = os.path.join(image_root, img_dict["file_name"])
record["height"] = img_dict["height"]
record["width"] = img_dict["width"]
image_id = record["image_id"] = img_dict["id"]
objs = []
for anno in anno_dict_list:
assert anno["image_id"] == image_id
assert anno.get("ignore", 0) == 0, '"ignore" in COCO json file is not supported.'
obj = {key: anno[key] for key in ann_keys if key in anno}
if "bbox" in obj and len(obj["bbox"]) == 0:
raise ValueError(
f"One annotation of image {image_id} contains empty 'bbox' value! "
"This json does not have valid COCO format."
)
segm = anno.get("segmentation", None)
if segm:
if isinstance(segm, dict):
if isinstance(segm["counts"], list):
segm = mask_util.frPyObjects(segm, *segm["size"])
else:
segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6]
if len(segm) == 0:
num_instances_without_valid_segmentation += 1
continue
obj["segmentation"] = segm
keypts = anno.get("keypoints", None)
if keypts:
for idx, v in enumerate(keypts):
if idx % 3 != 2:
keypts[idx] = v + 0.5
obj["keypoints"] = keypts
obj["bbox_mode"] = BoxMode.XYWH_ABS
if id_map:
annotation_category_id = obj["category_id"]
try:
obj["category_id"] = id_map[annotation_category_id]
except KeyError as e:
raise KeyError(
f"Encountered category_id={annotation_category_id} "
"but this id does not exist in 'categories' of the json file."
) from e
objs.append(obj)
record["annotations"] = objs
dataset_dicts.append(record)
if num_instances_without_valid_segmentation > 0:
logger.warning(
"Filtered out {} instances without valid segmentation. ".format(
num_instances_without_valid_segmentation
)
+ "There might be issues in your dataset generation process. Please "
"check https://detectron2.readthedocs.io/en/latest/tutorials/datasets.html carefully"
)
return dataset_dicts
|
- 将图像和注释配对。
- 构建每个图像的字典,包含图像信息和注释。
- 验证并处理每个注释,确保其格式正确。
- 将每个图像的字典添加到数据集字典列表中。
- 返回构建好的数据集字典列表。
总结
这个函数从 COCO 格式的 JSON 文件加载注释,并将其转换为 Detectron2 标准数据集字典格式。它处理类别 ID 的映射,验证注释的唯一性,并根据需要添加额外的注释键。最后返回一个包含所有图像和注释信息的数据集字
get_detection_dataset_dicts
这段代码定义了一个函数 get_detection_dataset_dicts
,用于加载和准备用于实例检测/分割和语义分割的数据集字典。它可以从一个或多个数据集名称中获取数据,并根据给定的条件进行筛选和验证。以下是对代码的详细解释:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
| def get_detection_dataset_dicts(
names,
filter_empty=True,
min_keypoints=0,
proposal_files=None,
check_consistency=True,
):
"""
Load and prepare dataset dicts for instance detection/segmentation and semantic segmentation.
Args:
names (str or list[str]): a dataset name or a list of dataset names
filter_empty (bool): whether to filter out images without instance annotations
min_keypoints (int): filter out images with fewer keypoints than
`min_keypoints`. Set to 0 to do nothing.
proposal_files (list[str]): if given, a list of object proposal files
that match each dataset in `names`.
check_consistency (bool): whether to check if datasets have consistent metadata.
Returns:
list[dict]: a list of dicts following the standard dataset dict format.
"""
if isinstance(names, str):
names = [names]
assert len(names), names
available_datasets = DatasetCatalog.keys()
names_set = set(names)
if not names_set.issubset(available_datasets):
logger = logging.getLogger(__name__)
logger.warning(
"The following dataset names are not registered in the DatasetCatalog: "
f"{names_set - available_datasets}. "
f"Available datasets are {available_datasets}"
)
dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in names]
if isinstance(dataset_dicts[0], torchdata.Dataset):
if len(dataset_dicts) > 1:
# ConcatDataset does not work for iterable style dataset.
# We could support concat for iterable as well, but it's often
# not a good idea to concat iterables anyway.
return torchdata.ConcatDataset(dataset_dicts)
return dataset_dicts[0]
for dataset_name, dicts in zip(names, dataset_dicts):
assert len(dicts), "Dataset '{}' is empty!".format(dataset_name)
if proposal_files is not None:
assert len(names) == len(proposal_files)
# load precomputed proposals from proposal files
dataset_dicts = [
load_proposals_into_dataset(dataset_i_dicts, proposal_file)
for dataset_i_dicts, proposal_file in zip(dataset_dicts, proposal_files)
]
dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts))
has_instances = "annotations" in dataset_dicts[0]
if filter_empty and has_instances:
dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts)
if min_keypoints > 0 and has_instances:
dataset_dicts = filter_images_with_few_keypoints(dataset_dicts, min_keypoints)
if check_consistency and has_instances:
try:
class_names = MetadataCatalog.get(names[0]).thing_classes
check_metadata_consistency("thing_classes", names)
print_instances_class_histogram(dataset_dicts, class_names)
except AttributeError: # class names are not available for this dataset
pass
assert len(dataset_dicts), "No valid data found in {}.".format(",".join(names))
return dataset_dicts
|
代码详解:
- 函数定义与参数:
names
:数据集名称或数据集名称列表。
filter_empty
:是否过滤掉没有实例注释的图像。
min_keypoints
:过滤掉关键点数量少于 min_keypoints
的图像。设置为 0 则不进行过滤。
proposal_files
:如果提供,是与 names
中的每个数据集匹配的对象提案文件列表。
check_consistency
:是否检查数据集的元数据是否一致。
- 处理数据集名称:
1
2
3
| if isinstance(names, str):
names = [names]
assert len(names), names
|
- 如果
names
是字符串,将其转换为包含一个元素的列表。
- 确保
names
列表不为空。
- 检查数据集是否注册:
1
2
3
4
5
6
7
8
9
| available_datasets = DatasetCatalog.keys()
names_set = set(names)
if not names_set.issubset(available_datasets):
logger = logging.getLogger(__name__)
logger.warning(
"The following dataset names are not registered in the DatasetCatalog: "
f"{names_set - available_datasets}. "
f"Available datasets are {available_datasets}"
)
|
- 获取已注册的数据集名称。
- 检查
names
中的名称是否都在已注册的数据集中,如果不是,记录警告信息。
- 获取数据集字典:
1
2
3
4
5
6
| dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in names]
if isinstance(dataset_dicts[0], torchdata.Dataset):
if len(dataset_dicts) > 1:
return torchdata.ConcatDataset(dataset_dicts)
return dataset_dicts[0]
|
- 从
DatasetCatalog
中获取每个数据集的字典。
- 如果数据集是
torchdata.Dataset
类型,并且数量超过一个,合并数据集,否则返回单个数据集。
- 检查数据集是否为空:
1
2
| for dataset_name, dicts in zip(names, dataset_dicts):
assert len(dicts), "Dataset '{}' is empty!".format(dataset_name)
|
- 加载目标提议文件:
1
2
3
4
5
6
| if proposal_files is not None:
assert len(names) == len(proposal_files)
dataset_dicts = [
load_proposals_into_dataset(dataset_i_dicts, proposal_file)
for dataset_i_dicts, proposal_file in zip(dataset_dicts, proposal_files)
]
|
- 合并数据集字典列表:
1
| dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts))
|
- 过滤无效数据:
1
2
3
4
5
| has_instances = "annotations" in dataset_dicts[0]
if filter_empty and has_instances:
dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts)
if min_keypoints > 0 and has_instances:
dataset_dicts = filter_images_with_few_keypoints(dataset_dicts, min_keypoints)
|
- 检查元数据一致性:
1
2
3
4
5
6
7
| if check_consistency and has_instances:
try:
class_names = MetadataCatalog.get(names[0]).thing_classes
check_metadata_consistency("thing_classes", names)
print_instances_class_histogram(dataset_dicts, class_names)
except AttributeError:
pass
|
- 确保有有效数据:
1
2
| assert len(dataset_dicts), "No valid data found in {}.".format(",".join(names))
return dataset_dicts
|
总的来说,这个函数加载并准备数据集字典,支持实例检测和分割任务。它可以根据给定的条件过滤数据,并检查数据集的元数据一致性。
Mapper
这段代码是一个函数 transform_instance_annotations
,用于对单个实例的注释(包括边框、分割、多边形、关键点等)应用变换操作。下面是对代码的详细解释:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
| def transform_instance_annotations(
annotation, transforms, image_size, *, keypoint_hflip_indices=None
):
"""
Apply transforms to box, segmentation and keypoints annotations of a single instance.
It will use `transforms.apply_box` for the box, and
`transforms.apply_coords` for segmentation polygons & keypoints.
If you need anything more specially designed for each data structure,
you'll need to implement your own version of this function or the transforms.
Args:
annotation (dict): dict of instance annotations for a single instance.
It will be modified in-place.
transforms (TransformList or list[Transform]):
image_size (tuple): the height, width of the transformed image
keypoint_hflip_indices (ndarray[int]): see `create_keypoint_hflip_indices`.
Returns:
dict:
the same input dict with fields "bbox", "segmentation", "keypoints"
transformed according to `transforms`.
The "bbox_mode" field will be set to XYXY_ABS.
"""
if isinstance(transforms, (tuple, list)):
transforms = T.TransformList(transforms)
# bbox is 1d (per-instance bounding box)
bbox = BoxMode.convert(annotation["bbox"], annotation["bbox_mode"], BoxMode.XYXY_ABS)
# clip transformed bbox to image size
bbox = transforms.apply_box(np.array([bbox]))[0].clip(min=0)
annotation["bbox"] = np.minimum(bbox, list(image_size + image_size)[::-1])
annotation["bbox_mode"] = BoxMode.XYXY_ABS
if "segmentation" in annotation:
# each instance contains 1 or more polygons
segm = annotation["segmentation"]
if isinstance(segm, list):
# polygons
polygons = [np.asarray(p).reshape(-1, 2) for p in segm]
annotation["segmentation"] = [
p.reshape(-1) for p in transforms.apply_polygons(polygons)
]
elif isinstance(segm, dict):
# RLE
mask = mask_util.decode(segm)
mask = transforms.apply_segmentation(mask)
assert tuple(mask.shape[:2]) == image_size
annotation["segmentation"] = mask
else:
raise ValueError(
"Cannot transform segmentation of type '{}'!"
"Supported types are: polygons as list[list[float] or ndarray],"
" COCO-style RLE as a dict.".format(type(segm))
)
if "keypoints" in annotation:
keypoints = transform_keypoint_annotations(
annotation["keypoints"], transforms, image_size, keypoint_hflip_indices
)
annotation["keypoints"] = keypoints
return annotation
|
代码详解:
- 函数定义:
annotation
:字典,包含单个实例的注释数据,会在函数中被修改。
transforms
:变换操作,可以是 TransformList
或 Transform
列表。
image_size
:元组,表示图像的高度和宽度。
keypoint_hflip_indices
(可选):关键点水平翻转的索引数组。
- 处理变换列表:
1
2
| if isinstance(transforms, (tuple, list)):
transforms = T.TransformList(transforms)
|
如果 transforms
是元组或列表,将其转换为 TransformList
。
- 处理边框
bbox
:
1
2
3
4
| bbox = BoxMode.convert(annotation["bbox"], annotation["bbox_mode"], BoxMode.XYXY_ABS)
bbox = transforms.apply_box(np.array([bbox]))[0].clip(min=0)
annotation["bbox"] = np.minimum(bbox, list(image_size + image_size)[::-1])
annotation["bbox_mode"] = BoxMode.XYXY_ABS
|
- 将边框转换为 XYXY 绝对坐标模式。
- 应用变换并裁剪到图像大小。
- 更新注释中的边框和边框模式。
- 处理分割
segmentation
:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
| if "segmentation" in annotation:
segm = annotation["segmentation"]
if isinstance(segm, list):
polygons = [np.asarray(p).reshape(-1, 2) for p in segm]
annotation["segmentation"] = [
p.reshape(-1) for p in transforms.apply_polygons(polygons)
]
elif isinstance(segm, dict):
mask = mask_util.decode(segm)
mask = transforms.apply_segmentation(mask)
assert tuple(mask.shape[:2]) == image_size
annotation["segmentation"] = mask
else:
raise ValueError(
"Cannot transform segmentation of type '{}'!"
"Supported types are: polygons as list[list[float] or ndarray],"
" COCO-style RLE as a dict.".format(type(segm))
)
|
- 如果分割是多边形列表,应用多边形变换。
- 如果分割是 RLE(Run-Length Encoding),解码后应用变换。
- 更新注释中的分割数据。
- 处理关键点
keypoints
:
1
2
3
4
5
| if "keypoints" in annotation:
keypoints = transform_keypoint_annotations(
annotation["keypoints"], transforms, image_size, keypoint_hflip_indices
)
annotation["keypoints"] = keypoints
|
- 如果存在关键点,调用
transform_keypoint_annotations
函数进行变换并更新注释中的关键点数据。
- 返回值:
返回经过变换的注释字典。
总的来说,这个函数根据给定的变换对单个实例的边框、分割和关键点进行相应的变换,并更新注释字典。
annotations_to_instances
这段代码定义了一个函数 annotations_to_instances
,用于将数据集字典中的实例注释转换为模型使用的 Instances
对象。具体来说,它会将实例的边框、类别、掩码和关键点转换为模型所需的格式。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
| def annotations_to_instances(annos, image_size, mask_format="polygon"):
"""
Create an :class:`Instances` object used by the models,
from instance annotations in the dataset dict.
Args:
annos (list[dict]): a list of instance annotations in one image, each
element for one instance.
image_size (tuple): height, width
Returns:
Instances:
It will contain fields "gt_boxes", "gt_classes",
"gt_masks", "gt_keypoints", if they can be obtained from `annos`.
This is the format that builtin models expect.
"""
boxes = (
np.stack(
[BoxMode.convert(obj["bbox"], obj["bbox_mode"], BoxMode.XYXY_ABS) for obj in annos]
)
if len(annos)
else np.zeros((0, 4))
)
target = Instances(image_size)
target.gt_boxes = Boxes(boxes)
classes = [int(obj["category_id"]) for obj in annos]
classes = torch.tensor(classes, dtype=torch.int64)
target.gt_classes = classes
if len(annos) and "segmentation" in annos[0]:
segms = [obj["segmentation"] for obj in annos]
if mask_format == "polygon":
try:
masks = PolygonMasks(segms)
except ValueError as e:
raise ValueError(
"Failed to use mask_format=='polygon' from the given annotations!"
) from e
else:
assert mask_format == "bitmask", mask_format
masks = []
for segm in segms:
if isinstance(segm, list):
masks.append(polygons_to_bitmask(segm, *image_size))
elif isinstance(segm, dict):
masks.append(mask_util.decode(segm))
elif isinstance(segm, np.ndarray):
assert segm.ndim == 2, "Expect segmentation of 2 dimensions, got {}.".format(
segm.ndim
)
masks.append(segm)
else:
raise ValueError(
"Cannot convert segmentation of type '{}' to BitMasks!"
"Supported types are: polygons as list[list[float] or ndarray],"
" COCO-style RLE as a dict, or a binary segmentation mask "
" in a 2D numpy array of shape HxW.".format(type(segm))
)
masks = BitMasks(
torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in masks])
)
target.gt_masks = masks
if len(annos) and "keypoints" in annos[0]:
kpts = [obj.get("keypoints", []) for obj in annos]
target.gt_keypoints = Keypoints(kpts)
return target
|
代码详解:
- 函数定义与参数:
annos
:包含一个图像中所有实例注释的字典列表。
image_size
:图像的高度和宽度。
mask_format
:掩码格式,默认为 "polygon"
。
- 边框处理:
1
2
3
4
5
6
7
8
9
| boxes = (
np.stack(
[BoxMode.convert(obj["bbox"], obj["bbox_mode"], BoxMode.XYXY_ABS) for obj in annos]
)
if len(annos)
else np.zeros((0, 4))
)
target = Instances(image_size)
target.gt_boxes = Boxes(boxes)
|
- 将每个实例的边框转换为 XYXY 绝对坐标格式。
- 如果没有注释,创建一个形状为
(0, 4)
的零数组。
- 将边框信息存储在
target
对象的 gt_boxes
字段中。
- 类别处理:
1
2
3
| classes = [int(obj["category_id"]) for obj in annos]
classes = torch.tensor(classes, dtype=torch.int64)
target.gt_classes = classes
|
- 提取每个实例的类别 ID 并转换为
torch.tensor
格式。
- 将类别信息存储在
target
对象的 gt_classes
字段中。
- 分割处理:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
| if len(annos) and "segmentation" in annos[0]:
segms = [obj["segmentation"] for obj in annos]
if mask_format == "polygon":
try:
masks = PolygonMasks(segms)
except ValueError as e:
raise ValueError(
"Failed to use mask_format=='polygon' from the given annotations!"
) from e
else:
assert mask_format == "bitmask", mask_format
masks = []
for segm in segms:
if isinstance(segm, list):
masks.append(polygons_to_bitmask(segm, *image_size))
elif isinstance(segm, dict):
masks.append(mask_util.decode(segm))
elif isinstance(segm, np.ndarray):
assert segm.ndim == 2, "Expect segmentation of 2 dimensions, got {}.".format(
segm.ndim
)
masks.append(segm)
else:
raise ValueError(
"Cannot convert segmentation of type '{}' to BitMasks!"
"Supported types are: polygons as list[list[float] or ndarray],"
" COCO-style RLE as a dict, or a binary segmentation mask "
" in a 2D numpy array of shape HxW.".format(type(segm))
)
masks = BitMasks(
torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in masks])
)
target.gt_masks = masks
|
- 检查是否存在分割注释,并根据
mask_format
处理多边形或位掩码格式的分割。
- 如果分割格式为多边形,尝试将其转换为
PolygonMasks
对象。
- 如果分割格式为位掩码,转换不同类型的分割注释为
BitMasks
对象。
- 将分割信息存储在
target
对象的 gt_masks
字段中。
- 关键点处理:
1
2
3
| if len(annos) and "keypoints" in annos[0]:
kpts = [obj.get("keypoints", []) for obj in annos]
target.gt_keypoints = Keypoints(kpts)
|
- 检查是否存在关键点注释。
- 提取每个实例的关键点,并存储在
target
对象的 gt_keypoints
字段中。
- 返回值:
返回包含边框、类别、掩码和关键点信息的 Instances
对象。
总的来说,这个函数将数据集中的实例注释转换为模型可以直接使用的 Instances
对象,并将其返回。
Dataloader
Sampler
Model
Backbone
FPN
RPN
输入:
输出:
RPN(FCOS)
输出:
Instances
RoI Head
Trainer
TrainerBase
这段代码定义了一个 TrainerBase
类,作为一个带有钩子的迭代训练器的基类。它假定训练在一个循环中运行,子类可以实现具体的循环逻辑。以下是代码的详细解释:
类定义及属性
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
| class TrainerBase:
"""
Base class for iterative trainer with hooks.
The only assumption we made here is: the training runs in a loop.
A subclass can implement what the loop is.
We made no assumptions about the existence of dataloader, optimizer, model, etc.
Attributes:
iter(int): the current iteration.
start_iter(int): The iteration to start with.
By convention the minimum possible value is 0.
max_iter(int): The iteration to end training.
storage(EventStorage): An EventStorage that's opened during the course of training.
"""
|
TrainerBase
类是一个带有钩子的迭代训练器基类。它假定训练在一个循环中运行,具体的循环逻辑由子类实现。该类没有假设数据加载器、优化器、模型等的存在。
初始化方法
1
2
3
4
5
6
7
| def __init__(self) -> None:
self._hooks: List[HookBase] = []
self.iter: int = 0
self.start_iter: int = 0
self.max_iter: int
self.storage: EventStorage
_log_api_usage("trainer." + self.__class__.__name__)
|
_hooks
: 用于存储钩子的列表。
iter
: 当前迭代次数。
start_iter
: 开始迭代次数。
max_iter
: 最大迭代次数。
storage
: 一个 EventStorage
对象,用于在训练过程中存储事件。
_log_api_usage
: 记录 API 使用情况。
注册钩子
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
| def register_hooks(self, hooks: List[Optional[HookBase]]) -> None:
"""
Register hooks to the trainer. The hooks are executed in the order
they are registered.
Args:
hooks (list[Optional[HookBase]]): list of hooks
"""
hooks = [h for h in hooks if h is not None]
for h in hooks:
assert isinstance(h, HookBase)
# To avoid circular reference, hooks and trainer cannot own each other.
# This normally does not matter, but will cause memory leak if the
# involved objects contain __del__:
# See http://engineering.hearsaysocial.com/2013/06/16/circular-references-in-python/
h.trainer = weakref.proxy(self)
self._hooks.extend(hooks)
|
- 该方法注册钩子到训练器中,钩子按注册顺序执行。
- 确保钩子是
HookBase
的实例,并使用 weakref.proxy
避免循环引用导致的内存泄漏。
训练方法
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
| def train(self, start_iter: int, max_iter: int):
"""
Args:
start_iter, max_iter (int): See docs above
"""
logger = logging.getLogger(__name__)
logger.info("Starting training from iteration {}".format(start_iter))
self.iter = self.start_iter = start_iter
self.max_iter = max_iter
with EventStorage(start_iter) as self.storage:
try:
self.before_train()
for self.iter in range(start_iter, max_iter):
self.before_step()
self.run_step()
self.after_step()
# self.iter == max_iter can be used by `after_train` to
# tell whether the training successfully finished or failed
# due to exceptions.
self.iter += 1
except Exception:
logger.exception("Exception during training:")
raise
finally:
self.after_train()
|
- 训练方法设置了训练的开始迭代次数和最大迭代次数。
- 使用
EventStorage
存储事件。
- 在训练过程中执行以下步骤:
before_train
:在训练开始前执行。
before_step
:在每一步前执行。
run_step
:执行每一步的训练逻辑,由子类实现。
after_step
:在每一步后执行。
after_train
:在训练结束后执行。
钩子方法
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
| def before_train(self):
for h in self._hooks:
h.before_train()
def after_train(self):
self.storage.iter = self.iter
for h in self._hooks:
h.after_train()
def before_step(self):
self.storage.iter = self.iter
for h in self._hooks:
h.before_step()
def after_backward(self):
for h in self._hooks:
h.after_backward()
def after_step(self):
for h in self._hooks:
h.after_step()
|
这些方法在训练的不同阶段调用注册的钩子方法,以便执行相应的操作。
运行步骤方法
1
2
| def run_step(self):
raise NotImplementedError
|
- 这是一个抽象方法,子类必须实现具体的训练步骤逻辑。
保存和加载状态
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
| def state_dict(self):
ret = {"iteration": self.iter}
hooks_state = {}
for h in self._hooks:
sd = h.state_dict()
if sd:
name = type(h).__qualname__
if name in hooks_state:
# TODO handle repetitive stateful hooks
continue
hooks_state[name] = sd
if hooks_state:
ret["hooks"] = hooks_state
return ret
def load_state_dict(self, state_dict):
logger = logging.getLogger(__name__)
self.iter = state_dict["iteration"]
for key, value in state_dict.get("hooks", {}).items():
for h in self._hooks:
try:
name = type(h).__qualname__
except AttributeError:
continue
if name == key:
h.load_state_dict(value)
break
else:
logger.warning(f"Cannot find the hook '{key}', its state_dict is ignored.")
|
state_dict
方法返回训练器的状态,包括当前迭代次数和钩子的状态。
load_state_dict
方法加载训练器的状态,包括迭代次数和钩子的状态。
总结
TrainerBase
类提供了一个带有钩子的迭代训练器的基类,假定训练在一个循环中运行,具体的循环逻辑由子类实现。通过注册钩子,可以在训练的不同阶段执行相应的操作,从而灵活地管理训练过程。
weakref.proxy(self)
是 Python 的 weakref
模块中的一种机制,用于创建一个对对象的弱引用代理。弱引用代理允许你引用一个对象而不会增加它的引用计数,从而避免循环引用导致的内存泄漏问题。下面是 weakref.proxy(self)
的详细解释:
背景
在某些情况下,两个对象之间可能会互相引用,这会导致循环引用。循环引用使得垃圾回收机制无法正确地回收这些对象,从而导致内存泄漏。
例如,在 TrainerBase
类中,训练器和钩子可能会互相引用。如果 TrainerBase
直接持有对钩子的强引用,而钩子又持有对训练器的强引用,则会产生循环引用。
使用 weakref.proxy(self)
为了避免这种循环引用问题,可以使用 weakref.proxy(self)
。这会创建一个对当前对象 self
的弱引用代理。弱引用代理不会增加对象的引用计数,因此即使对象被代理,它也能被垃圾回收。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
| import weakref
class HookBase:
def __init__(self):
self.trainer = None
class TrainerBase:
def __init__(self):
self._hooks = []
def register_hooks(self, hooks):
hooks = [h for h in hooks if h is not None]
for h in hooks:
assert isinstance(h, HookBase)
# 使用弱引用代理避免循环引用
h.trainer = weakref.proxy(self)
self._hooks.extend(hooks)
|
具体解释
weakref.proxy(self)
创建了一个对 self
(即当前对象)的弱引用代理。
- 这个代理对象可以像原始对象一样被使用,但不会增加对象的引用计数。
- 当原始对象被垃圾回收时,代理对象也会自动变成无效的。
优势
- 避免循环引用:通过使用弱引用代理,可以避免循环引用导致的内存泄漏。
- 节省内存:由于弱引用不会增加对象的引用计数,可以更有效地管理内存。
- 安全性:如果代理对象在原始对象被回收后被访问,会引发一个
ReferenceError
异常,从而确保代码的安全性和稳定性。
例子
下面是一个简单的示例,演示如何使用弱引用代理来避免循环引用:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
| import weakref
class A:
def __init__(self, name):
self.name = name
class B:
def __init__(self, a):
# 使用弱引用代理
self.a = weakref.proxy(a)
# 创建对象
a = A("example")
b = B(a)
print(b.a.name) # 输出: example
# 删除原始对象
del a
try:
print(b.a.name) # 试图访问已删除的对象,抛出 ReferenceError
except ReferenceError:
print("Original object has been garbage collected")
|
在这个例子中,B
类中的 self.a
是对 A
类实例的弱引用代理。因此,即使 A
类实例被删除,B
类实例不会阻止垃圾回收。当试图访问已经被回收的对象时,会抛出 ReferenceError
异常。
总结来说,weakref.proxy(self)
用于创建对对象的弱引用代理,从而避免循环引用导致的内存泄漏,并允许更有效的内存管理。
SimpleTrainer
这段代码定义了一个 SimpleTrainer
类,它继承自 TrainerBase
,用于执行最常见类型的任务:单一损失函数、单一优化器、单一数据源的迭代优化,支持数据并行。下面是对代码的详细解释:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
| class SimpleTrainer(TrainerBase):
"""
A simple trainer for the most common type of task:
single-cost single-optimizer single-data-source iterative optimization,
optionally using data-parallelism.
It assumes that every step, you:
1. Compute the loss with a data from the data_loader.
2. Compute the gradients with the above loss.
3. Update the model with the optimizer.
All other tasks during training (checkpointing, logging, evaluation, LR schedule)
are maintained by hooks, which can be registered by :meth:`TrainerBase.register_hooks`.
If you want to do anything fancier than this,
either subclass TrainerBase and implement your own `run_step`,
or write your own training loop.
"""
|
SimpleTrainer
类用于最常见类型的任务,假设每一步都执行以下操作:
- 使用
data_loader
的数据计算损失。
- 用上述损失计算梯度。
- 使用优化器更新模型。
其它任务如检查点保存、日志记录、评估、学习率调度通过 TrainerBase
的钩子实现。
初始化方法
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
| def __init__(
self,
model,
data_loader,
optimizer,
gather_metric_period=1,
zero_grad_before_forward=False,
async_write_metrics=False,
):
"""
Args:
model: a torch Module. Takes a data from data_loader and returns a
dict of losses.
data_loader: an iterable. Contains data to be used to call model.
optimizer: a torch optimizer.
gather_metric_period: an int. Every gather_metric_period iterations
the metrics are gathered from all the ranks to rank 0 and logged.
zero_grad_before_forward: whether to zero the gradients before the forward.
async_write_metrics: bool. If True, then write metrics asynchronously to improve
training speed
"""
super().__init__()
model.train()
self.model = model
self.data_loader = data_loader
self._data_loader_iter_obj = None
self.optimizer = optimizer
self.gather_metric_period = gather_metric_period
self.zero_grad_before_forward = zero_grad_before_forward
self.async_write_metrics = async_write_metrics
self.concurrent_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
|
主要属性
model
: 要训练的模型。
data_loader
: 数据加载器,提供训练数据。
optimizer
: 用于更新模型参数的优化器。
gather_metric_period
: 指定每隔多少次迭代收集一次指标。
zero_grad_before_forward
: 是否在前向传播前将梯度清零。
async_write_metrics
: 是否异步写入指标。
concurrent_executor
: 用于异步执行非关键逻辑的线程池。
run_step
方法
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
| def run_step(self):
"""
Implement the standard training logic described above.
"""
assert self.model.training, "[SimpleTrainer] model was changed to eval mode!"
start = time.perf_counter()
data = next(self._data_loader_iter)
data_time = time.perf_counter() - start
if self.zero_grad_before_forward:
self.optimizer.zero_grad()
loss_dict = self.model(data)
if isinstance(loss_dict, torch.Tensor):
losses = loss_dict
loss_dict = {"total_loss": loss_dict}
else:
losses = sum(loss_dict.values())
if not self.zero_grad_before_forward:
self.optimizer.zero_grad()
losses.backward()
self.after_backward()
if self.async_write_metrics:
self.concurrent_executor.submit(
self._write_metrics, loss_dict, data_time, iter=self.iter
)
else:
self._write_metrics(loss_dict, data_time)
self.optimizer.step()
|
主要步骤
- 断言模型处于训练模式
1
| assert self.model.training, "[SimpleTrainer] model was changed to eval mode!"
|
- 获取数据并计算数据加载时间
1
2
| data = next(self._data_loader_iter)
data_time = time.perf_counter() - start
|
- 前向传播前清零梯度
1
2
| if self.zero_grad_before_forward:
self.optimizer.zero_grad()
|
- 计算损失
1
2
3
4
5
6
| loss_dict = self.model(data)
if isinstance(loss_dict, torch.Tensor):
losses = loss_dict
loss_dict = {"total_loss": loss_dict}
else:
losses = sum(loss_dict.values())
|
- 后向传播前清零梯度(如果未在前向传播前清零)
1
2
| if not self.zero_grad_before_forward:
self.optimizer.zero_grad()
|
- 后向传播
1
2
| losses.backward()
self.after_backward()
|
- 写入指标(同步或异步)
1
2
3
4
5
6
| if self.async_write_metrics:
self.concurrent_executor.submit(
self._write_metrics, loss_dict, data_time, iter=self.iter
)
else:
self._write_metrics(loss_dict, data_time)
|
- 更新模型参数
其他方法
_data_loader_iter
: 只在首次使用时创建数据加载器迭代器。
1
2
3
4
5
| @property
def _data_loader_iter(self):
if self._data_loader_iter_obj is None:
self._data_loader_iter_obj = iter(self.data_loader)
return self._data_loader_iter_obj
|
reset_data_loader
: 重置数据加载器。
1
2
3
4
5
| def reset_data_loader(self, data_loader_builder):
del self.data_loader
data_loader = data_loader_builder()
self.data_loader = data_loader
self._data_loader_iter_obj = None
|
_write_metrics
和 write_metrics
: 写入训练指标。
1
2
3
4
5
6
| def _write_metrics(self, loss_dict, data_time, prefix="", iter=None):
# ...
@staticmethod
def write_metrics(loss_dict, data_time, cur_iter, prefix=""):
# ...
|
state_dict
和 load_state_dict
: 保存和加载训练状态。
1
2
3
4
5
6
7
8
| def state_dict(self):
ret = super().state_dict()
ret["optimizer"] = self.optimizer.state_dict()
return ret
def load_state_dict(self, state_dict):
super().load_state_dict(state_dict)
self.optimizer.load_state_dict(state_dict["optimizer"])
|
after_train
: 在训练结束后关闭异步执行器。
1
2
3
| def after_train(self):
super().after_train()
self.concurrent_executor.shutdown(wait=True)
|
总结
SimpleTrainer
类实现了一个简单但功能齐全的训练循环,包括损失计算、梯度更新、模型参数更新等,同时提供了异步写入指标和重置数据加载器的功能。这种设计使得代码清晰易懂,并且可以通过钩子灵活扩展功能。
Hook
HookBase
这段代码定义了一个 HookBase
类,用于实现可以注册到 TrainerBase
中的钩子。钩子是一些在训练过程中不同阶段执行的回调函数,用于在训练过程中插入自定义操作。以下是代码的详细解释:
类定义及文档
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
| class HookBase:
"""
Base class for hooks that can be registered with :class:`TrainerBase`.
Each hook can implement 4 methods. The way they are called is demonstrated
in the following snippet:
::
hook.before_train()
for iter in range(start_iter, max_iter):
hook.before_step()
trainer.run_step()
hook.after_step()
iter += 1
hook.after_train()
Notes:
1. In the hook method, users can access ``self.trainer`` to access more
properties about the context (e.g., model, current iteration, or config
if using :class:`DefaultTrainer`).
2. A hook that does something in :meth:`before_step` can often be
implemented equivalently in :meth:`after_step`.
If the hook takes non-trivial time, it is strongly recommended to
implement the hook in :meth:`after_step` instead of :meth:`before_step`.
The convention is that :meth:`before_step` should only take negligible time.
Following this convention will allow hooks that do care about the difference
between :meth:`before_step` and :meth:`after_step` (e.g., timer) to
function properly.
"""
|
HookBase
类是一个基类,用于实现可以注册到 TrainerBase
中的钩子。钩子在训练的不同阶段执行特定的操作。文档中描述了钩子的调用顺序和一些实现细节:
- 钩子方法
before_train
、before_step
、after_step
、after_train
将按顺序在训练的相应阶段被调用。
- 用户可以在钩子方法中通过
self.trainer
访问训练器的属性和状态。
- 建议在
after_step
而不是 before_step
中实现耗时操作,以避免影响训练效率。
属性
1
2
3
4
| trainer: "TrainerBase" = None
"""
A weak reference to the trainer object. Set by the trainer when the hook is registered.
"""
|
trainer
是对训练器对象的弱引用,避免循环引用导致的内存泄漏。
- 这个属性在钩子注册时由训练器设置。
方法
before_train
1
2
3
4
5
| def before_train(self):
"""
Called before the first iteration.
"""
pass
|
- 在第一次迭代前调用,可以在这里初始化一些训练前需要的资源或状态。
after_train
1
2
3
4
5
| def after_train(self):
"""
Called after the last iteration.
"""
pass
|
- 在最后一次迭代后调用,可以在这里进行一些清理工作或保存训练结果。
before_step
1
2
3
4
5
| def before_step(self):
"""
Called before each iteration.
"""
pass
|
- 在每次迭代前调用,可以在这里执行一些准备工作。
- 建议避免在这里执行耗时操作。
after_backward
1
2
3
4
5
| def after_backward(self):
"""
Called after the backward pass of each iteration.
"""
pass
|
- 在每次迭代的反向传播后调用,可以在这里执行一些与梯度相关的操作。
after_step
1
2
3
4
5
| def after_step(self):
"""
Called after each iteration.
"""
pass
|
- 在每次迭代后调用,可以在这里执行一些总结性工作或记录日志。
state_dict
1
2
3
4
5
6
| def state_dict(self):
"""
Hooks are stateless by default, but can be made checkpointable by
implementing `state_dict` and `load_state_dict`.
"""
return {}
|
- 返回钩子的状态,以便在训练检查点中保存。
- 钩子默认是无状态的,如果需要保存状态,可以重载此方法返回状态字典。
总结
HookBase
类提供了一个框架,用于在训练过程中插入自定义操作。通过实现钩子的不同方法,可以在训练的不同阶段执行特定的操作。这个设计使得训练过程具有高度的可扩展性和灵活性,用户可以根据需要定制训练逻辑。
CallbackHook
这段代码定义了一个 CallbackHook
类,继承自 HookBase
。CallbackHook
类允许用户通过回调函数来创建钩子,使得在训练过程的不同阶段执行用户定义的函数。以下是代码的详细解释:
类定义及文档
1
2
3
4
| class CallbackHook(HookBase):
"""
Create a hook using callback functions provided by the user.
"""
|
CallbackHook
类继承自 HookBase
,用于通过用户提供的回调函数创建钩子。
初始化方法
1
2
3
4
5
6
7
8
| def __init__(self, *, before_train=None, after_train=None, before_step=None, after_step=None):
"""
Each argument is a function that takes one argument: the trainer.
"""
self._before_train = before_train
self._before_step = before_step
self._after_step = after_step
self._after_train = after_train
|
- 初始化方法接收四个可选的回调函数参数:
before_train
: 在训练开始前调用。
after_train
: 在训练结束后调用。
before_step
: 在每次迭代开始前调用。
after_step
: 在每次迭代结束后调用。
- 每个回调函数都接收一个参数,即训练器对象
trainer
。
回调方法
before_train
1
2
3
| def before_train(self):
if self._before_train:
self._before_train(self.trainer)
|
- 如果
before_train
回调函数存在,则在训练开始前调用它。
after_train
1
2
3
4
5
6
7
| def after_train(self):
if self._after_train:
self._after_train(self.trainer)
# The functions may be closures that hold reference to the trainer
# Therefore, delete them to avoid circular reference.
del self._before_train, self._after_train
del self._before_step, self._after_step
|
- 如果
after_train
回调函数存在,则在训练结束后调用它。
- 为了避免闭包中的循环引用,删除回调函数的引用。
before_step
1
2
3
| def before_step(self):
if self._before_step:
self._before_step(self.trainer)
|
- 如果
before_step
回调函数存在,则在每次迭代开始前调用它。
after_step
1
2
3
| def after_step(self):
if self._after_step:
self._after_step(self.trainer)
|
- 如果
after_step
回调函数存在,则在每次迭代结束后调用它。
总结
CallbackHook
类提供了一种灵活的方式,通过用户提供的回调函数来扩展训练过程的功能。每个回调函数在相应的训练阶段被调用,使得用户可以在训练的不同阶段执行特定的操作。通过删除回调函数的引用,可以避免由于闭包导致的循环引用问题。
这种设计使得训练过程更加模块化和可定制,用户可以根据需要定义在训练各个阶段执行的逻辑,而无需修改训练器的核心代码。
IterationTimer
这段代码定义了一个 IterationTimer
类,继承自 HookBase
,用于在训练过程中跟踪每次迭代所花费的时间,并在训练结束时打印摘要。IterationTimer
钩子利用 before_step
和 after_step
方法之间的时间来计算每次迭代的时间。以下是对代码的详细解释:
类定义及文档
1
2
3
4
5
6
7
8
9
10
11
| class IterationTimer(HookBase):
"""
Track the time spent for each iteration (each run_step call in the trainer).
Print a summary in the end of training.
This hook uses the time between the call to its :meth:`before_step`
and :meth:`after_step` methods.
Under the convention that :meth:`before_step` of all hooks should only
take negligible amount of time, the :class:`IterationTimer` hook should be
placed at the beginning of the list of hooks to obtain accurate timing.
"""
|
IterationTimer
类用于跟踪每次迭代所花费的时间,并在训练结束时打印摘要。
- 该钩子使用
before_step
和 after_step
方法之间的时间来计算每次迭代的时间。
- 为了获得准确的计时结果,
IterationTimer
钩子应该放在所有钩子的最前面,因为 before_step
方法应该只占用很少的时间。
初始化方法
1
2
3
4
5
6
7
8
9
10
| def __init__(self, warmup_iter=3):
"""
Args:
warmup_iter (int): the number of iterations at the beginning to exclude
from timing.
"""
self._warmup_iter = warmup_iter
self._step_timer = Timer()
self._start_time = time.perf_counter()
self._total_timer = Timer()
|
warmup_iter
:初始迭代次数,开始时忽略这些迭代的计时。
_step_timer
:用于计时每次迭代的时间。
_start_time
:用于记录训练开始的时间。
_total_timer
:用于计时整个训练过程的总时间。
before_train
方法
1
2
3
4
| def before_train(self):
self._start_time = time.perf_counter()
self._total_timer.reset()
self._total_timer.pause()
|
- 在训练开始前调用,记录开始时间,并重置和暂停总计时器。
after_train
方法
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
| def after_train(self):
logger = logging.getLogger(__name__)
total_time = time.perf_counter() - self._start_time
total_time_minus_hooks = self._total_timer.seconds()
hook_time = total_time - total_time_minus_hooks
num_iter = self.trainer.storage.iter + 1 - self.trainer.start_iter - self._warmup_iter
if num_iter > 0 and total_time_minus_hooks > 0:
logger.info(
"Overall training speed: {} iterations in {} ({:.4f} s / it)".format(
num_iter,
str(datetime.timedelta(seconds=int(total_time_minus_hooks))),
total_time_minus_hooks / num_iter,
)
)
logger.info(
"Total training time: {} ({} on hooks)".format(
str(datetime.timedelta(seconds=int(total_time))),
str(datetime.timedelta(seconds=int(hook_time))),
)
)
|
- 在训练结束后调用,计算总训练时间、钩子花费的时间,并打印训练速度和总训练时间的摘要。
before_step
方法
1
2
3
| def before_step(self):
self._step_timer.reset()
self._total_timer.resume()
|
- 在每次迭代开始前调用,重置步骤计时器并恢复总计时器。
after_step
方法
1
2
3
4
5
6
7
8
9
10
| def after_step(self):
iter_done = self.trainer.storage.iter - self.trainer.start_iter + 1
if iter_done >= self._warmup_iter:
sec = self._step_timer.seconds()
self.trainer.storage.put_scalars(time=sec)
else:
self._start_time = time.perf_counter()
self._total_timer.reset()
self._total_timer.pause()
|
- 在每次迭代结束后调用,检查是否超过预热迭代次数。
- 如果超过,则记录步骤时间。
- 否则,重置开始时间和总计时器。
- 暂停总计时器。
总结
IterationTimer
类通过记录 before_step
和 after_step
方法之间的时间来跟踪每次迭代所花费的时间。它在训练开始时初始化计时器,并在训练结束时打印训练摘要。此钩子非常适合用于评估训练过程中的时间消耗,以帮助优化训练过程。
PeriodicWriter
这段代码定义了一个 PeriodicWriter
类,继承自 HookBase
,用于周期性地将事件写入 EventStorage
,通过调用 writer.write()
方法。这个钩子在每隔指定的迭代次数和最后一次迭代后执行。
类定义及文档
1
2
3
4
5
6
7
| class PeriodicWriter(HookBase):
"""
Write events to EventStorage (by calling ``writer.write()``) periodically.
It is executed every ``period`` iterations and after the last iteration.
Note that ``period`` does not affect how data is smoothed by each writer.
"""
|
PeriodicWriter
类用于周期性地将事件写入 EventStorage
。
- 它在每隔指定的迭代次数(
period
)和最后一次迭代后执行。
period
仅决定写入的频率,不影响每个 writer 如何平滑数据。
初始化方法
1
2
3
4
5
6
7
8
9
10
| def __init__(self, writers, period=20):
"""
Args:
writers (list[EventWriter]): a list of EventWriter objects
period (int): frequency of writing events
"""
self._writers = writers
for w in writers:
assert isinstance(w, EventWriter), w
self._period = period
|
writers
:一个包含 EventWriter
对象的列表。
period
:指定写入事件的频率(默认值为 20 次迭代)。
after_step
方法
1
2
3
4
5
6
| def after_step(self):
if (self.trainer.iter + 1) % self._period == 0 or (
self.trainer.iter == self.trainer.max_iter - 1
):
for writer in self._writers:
writer.write()
|
- 在每次迭代结束后调用。
- 检查当前迭代次数是否为
period
的倍数,或是否为最后一次迭代。
- 如果条件满足,调用每个
writer
的 write()
方法写入事件。
after_train
方法
1
2
3
4
5
6
| def after_train(self):
for writer in self._writers:
# If any new data is found (e.g. produced by other after_train),
# write them before closing
writer.write()
writer.close()
|
- 在训练结束后调用。
- 调用每个
writer
的 write()
方法写入所有剩余的事件。
- 调用每个
writer
的 close()
方法关闭 writer。
总结
PeriodicWriter
类通过 writer.write()
方法周期性地将事件写入 EventStorage
。在每隔指定的迭代次数和训练结束后,它会执行写入操作。这样设计的好处是可以定期保存训练过程中生成的事件数据,便于监控和分析训练过程。以下是关键点:
- 初始化:接受一组
EventWriter
对象和一个写入周期 period
。
- 周期性写入:在每隔
period
次迭代和最后一次迭代后执行写入操作。
- 训练结束写入:在训练结束后,确保所有事件都被写入,并关闭 writer。
这种设计确保了事件数据的定期写入和记录,提供了灵活且有效的训练过程监控机制。
Checkpointer
这段代码定义了一个 Checkpointer
类,该类用于保存和加载模型以及其他可检查点的对象(如优化器、学习率调度器等)。下面是对代码的详细解释:
类定义及文档
1
2
3
4
5
| class Checkpointer:
"""
A checkpointer that can save/load model as well as extra checkpointable
objects.
"""
|
Checkpointer
类用于保存和加载模型及其他带有 state_dict
和 load_state_dict
方法的对象。
初始化方法
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
| def __init__(
self,
model: nn.Module,
save_dir: str = "",
*,
save_to_disk: bool = True,
**checkpointables: Any,
) -> None:
"""
Args:
model (nn.Module): model.
save_dir (str): a directory to save and find checkpoints.
save_to_disk (bool): if True, save checkpoint to disk, otherwise
disable saving for this checkpointer.
checkpointables (object): any checkpointable objects, i.e., objects
that have the ``state_dict()`` and ``load_state_dict()`` method. For
example, it can be used like
`Checkpointer(model, "dir", optimizer=optimizer)`.
"""
if isinstance(model, (DistributedDataParallel, DataParallel)):
model = model.module
self.model = model
self.checkpointables: Dict[str, Any] = {}
for k, v in checkpointables.items():
self.add_checkpointable(k, v)
self.logger: logging.Logger = logging.getLogger(__name__)
self.save_dir = save_dir
self.save_to_disk = save_to_disk
self.path_manager: PathManager = PathManager()
self.path_manager.register_handler(HTTPURLHandler())
|
model
:要保存的模型。
save_dir
:保存检查点的目录。
save_to_disk
:是否保存检查点到磁盘。
checkpointables
:其他可检查点的对象(如优化器),这些对象需要实现 state_dict
和 load_state_dict
方法。
- 如果模型是
DistributedDataParallel
或 DataParallel
,则使用其内部的实际模型。
- 将
checkpointables
添加到内部字典中,并注册默认的路径管理器。
添加检查点对象
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
| def add_checkpointable(self, key: str, checkpointable: Any) -> None:
"""
Add checkpointable object for this checkpointer to track.
Args:
key (str): the key used to save the object
checkpointable: any object with ``state_dict()`` and
``load_state_dict()`` method
"""
if key in self.checkpointables:
raise KeyError(f"Key {key} already used in the Checkpointer")
if not hasattr(checkpointable, "state_dict"):
raise TypeError(
"add_checkpointable needs an object with 'state_dict()' method."
)
self.checkpointables[key] = checkpointable
|
- 检查
key
是否已存在,防止重复。
- 确保
checkpointable
对象实现了 state_dict
方法。
- 将
checkpointable
对象添加到内部字典中。
保存检查点
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
| def save(self, name: str, **kwargs: Any) -> None:
"""
Dump model and checkpointables to a file.
Args:
name (str): name of the file.
kwargs (dict): extra arbitrary data to save.
"""
if not self.save_dir or not self.save_to_disk:
return
data = {}
data["model"] = self.model.state_dict()
for key, obj in self.checkpointables.items():
data[key] = obj.state_dict()
data.update(kwargs)
basename = "{}.pth".format(name)
save_file = os.path.join(self.save_dir, basename)
assert os.path.basename(save_file) == basename, basename
self.logger.info("Saving checkpoint to {}".format(save_file))
with self.path_manager.open(save_file, "wb") as f:
torch.save(data, cast(IO[bytes], f))
self.tag_last_checkpoint(basename)
|
- 检查是否需要保存。
- 将模型和所有检查点对象的
state_dict
保存到字典中,并包含额外的 kwargs
数据。
- 将数据保存到指定目录的文件中,并记录最后一个检查点。
加载检查点
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
| def load(
self, path: str, checkpointables: Optional[List[str]] = None
) -> Dict[str, Any]:
"""
Load from the given checkpoint.
Args:
path (str): path or url to the checkpoint. If empty, will not load
anything.
checkpointables (list): List of checkpointable names to load. If not
specified (None), will load all the possible checkpointables.
Returns:
dict:
extra data loaded from the checkpoint that has not been
processed. For example, those saved with
:meth:`.save(**extra_data)`.
"""
if not path:
self.logger.info("No checkpoint found. Initializing model from scratch")
return {}
self.logger.info("[Checkpointer] Loading from {} ...".format(path))
checkpoint = self._load_file(path)
incompatible = self._load_model(checkpoint)
if incompatible is not None:
self._log_incompatible_keys(incompatible)
for key in self.checkpointables if checkpointables is None else checkpointables:
if key in checkpoint:
self.logger.info("Loading {} from {} ...".format(key, path))
obj = self.checkpointables[key]
obj.load_state_dict(checkpoint.pop(key))
return checkpoint
|
- 从指定路径加载检查点。
- 检查点包括模型和其他可检查点对象的状态。
- 记录不兼容的键,并加载所有匹配的检查点对象。
- 返回检查点中的额外数据。
检查点辅助方法
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
| def has_checkpoint(self) -> bool:
"""
Returns:
bool: whether a checkpoint exists in the target directory.
"""
save_file = os.path.join(self.save_dir, "last_checkpoint")
return self.path_manager.exists(save_file)
def get_checkpoint_file(self) -> str:
"""
Returns:
str: The latest checkpoint file in target directory.
"""
save_file = os.path.join(self.save_dir, "last_checkpoint")
try:
with self.path_manager.open(save_file, "r") as f:
last_saved = f.read().strip()
except IOError:
return ""
return os.path.join(self.save_dir, last_saved)
def get_all_checkpoint_files(self) -> List[str]:
"""
Returns:
list: All available checkpoint files (.pth files) in target
directory.
"""
all_model_checkpoints = [
os.path.join(self.save_dir, file)
for file in self.path_manager.ls(self.save_dir)
if self.path_manager.isfile(os.path.join(self.save_dir, file))
and file.endswith(".pth")
]
return all_model_checkpoints
def resume_or_load(self, path: str, *, resume: bool = True) -> Dict[str, Any]:
"""
If `resume` is True, this method attempts to resume from the last
checkpoint, if exists. Otherwise, load checkpoint from the given path.
This is useful when restarting an interrupted training job.
Args:
path (str): path to the checkpoint.
resume (bool): if True, resume from the last checkpoint if it exists
and load the model together with all the checkpointables. Otherwise
only load the model without loading any checkpointables.
Returns:
same as :meth:`load`.
"""
if resume and self.has_checkpoint():
path = self.get_checkpoint_file()
return self.load(path)
else:
return self.load(path, checkpointables=[])
def tag_last_checkpoint(self, last_filename_basename: str) -> None:
"""
Tag the last checkpoint.
Args:
last_filename_basename (str): the basename of the last filename.
"""
save_file = os.path.join(self.save_dir, "last_checkpoint")
with self.path_manager.open(save_file, "w") as f:
f.write(last_filename_basename)
|
has_checkpoint
:检查是否存在检查点。
get_checkpoint_file
:获取最近的检查点文件。
get_all_checkpoint_files
:获取所有检查点文件。
resume_or_load
:尝试从最后一个检查点恢复或加载指定路径的检查点。
tag_last_checkpoint
:标记最后一个检查点。
加载文件和模型方法
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
| def _load_file(self, f: str) -> Dict[str, Any]:
"""
Load a checkpoint file. Can be overwritten by subclasses to support
different formats.
Args:
f (str): a locally mounted file path.
Returns:
dict: with keys "model" and optionally others that are saved by
the checkpointer dict["model"] must be a dict which maps strings
to torch.Tensor or numpy arrays.
"""
with self.path_manager.open(f, "rb") as file:
return torch.load(cast(IO[bytes], file), map_location=torch.device("cpu"))
def _load_model(self, checkpoint: Any) -> _IncompatibleKeys:
"""
Load weights from a checkpoint.
Args:
checkpoint (Any): checkpoint contains
## _PeriodicCheckpointer
这段代码定义了一个 `PeriodicCheckpointer` 类,用于定期保存检查点。它会在达到特定迭代次数或最大迭代次数时调用给定的 `checkpointer.save` 方法来保存检查点。以下是代码的详细解释:
### 类定义及文档
```python
class PeriodicCheckpointer:
"""
Save checkpoints periodically. When `.step(iteration)` is called, it will
execute `checkpointer.save` on the given checkpointer, if iteration is a
multiple of period or if `max_iter` is reached.
Attributes:
checkpointer (Checkpointer): the underlying checkpointer object
"""
|
PeriodicCheckpointer
类用于定期保存检查点。
- 在调用
.step(iteration)
方法时,如果当前迭代次数是 period
的倍数或达到 max_iter
,则执行给定的 checkpointer.save
方法保存检查点。
初始化方法
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
| def __init__(
self,
checkpointer: Checkpointer,
period: int,
max_iter: Optional[int] = None,
max_to_keep: Optional[int] = None,
file_prefix: str = "model",
) -> None:
"""
Args:
checkpointer: the checkpointer object used to save checkpoints.
period (int): the period to save checkpoint.
max_iter (int): maximum number of iterations. When it is reached,
a checkpoint named "{file_prefix}_final" will be saved.
max_to_keep (int): maximum number of most current checkpoints to keep,
previous checkpoints will be deleted
file_prefix (str): the prefix of checkpoint's filename
"""
self.checkpointer = checkpointer
self.period = int(period)
self.max_iter = max_iter
if max_to_keep is not None:
assert max_to_keep > 0
self.max_to_keep = max_to_keep
self.recent_checkpoints: List[str] = []
self.path_manager: PathManager = checkpointer.path_manager
self.file_prefix = file_prefix
|
checkpointer
:用于保存检查点的对象。
period
:保存检查点的周期。
max_iter
:最大迭代次数,达到时会保存一个名为 {file_prefix}_final
的检查点。
max_to_keep
:要保留的最新检查点的最大数量,超过这个数量的旧检查点会被删除。
file_prefix
:检查点文件名前缀。
step
方法
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
| def step(self, iteration: int, **kwargs: Any) -> None:
"""
Perform the appropriate action at the given iteration.
Args:
iteration (int): the current iteration, ranged in [0, max_iter-1].
kwargs (Any): extra data to save, same as in
:meth:`Checkpointer.save`.
"""
iteration = int(iteration)
additional_state = {"iteration": iteration}
additional_state.update(kwargs)
if (iteration + 1) % self.period == 0:
self.checkpointer.save(
"{}_{:07d}".format(self.file_prefix, iteration), **additional_state
)
if self.max_to_keep is not None:
self.recent_checkpoints.append(self.checkpointer.get_checkpoint_file())
if len(self.recent_checkpoints) > self.max_to_keep:
file_to_delete = self.recent_checkpoints.pop(0)
if self.path_manager.exists(
file_to_delete
) and not file_to_delete.endswith(f"{self.file_prefix}_final.pth"):
self.path_manager.rm(file_to_delete)
if self.max_iter is not None:
if iteration >= self.max_iter - 1:
self.checkpointer.save(f"{self.file_prefix}_final", **additional_state)
|
step
方法根据当前迭代次数 iteration
执行适当的操作。
- 如果当前迭代次数是
period
的倍数,保存检查点。
- 如果设置了
max_to_keep
,则维护一个最近检查点的列表,超过 max_to_keep
的检查点会被删除。
- 如果当前迭代次数达到或超过
max_iter
,保存一个名为 {file_prefix}_final
的检查点。
save
方法
1
2
3
4
5
6
7
8
9
10
11
| def save(self, name: str, **kwargs: Any) -> None:
"""
Same argument as :meth:`Checkpointer.save`.
Use this method to manually save checkpoints outside the schedule.
Args:
name (str): file name.
kwargs (Any): extra data to save, same as in
:meth:`Checkpointer.save`.
"""
self.checkpointer.save(name, **kwargs)
|
save
方法手动保存检查点,参数与 Checkpointer.save
相同。
- 可以在调度之外使用此方法手动保存检查点。
总结
PeriodicCheckpointer
类用于在训练过程中定期保存模型的检查点。它通过以下方式工作:
- 初始化:接受
checkpointer
对象、保存周期 period
、最大迭代次数 max_iter
、要保留的最新检查点数量 max_to_keep
和检查点文件名前缀 file_prefix
。
- 定期保存:在每隔
period
次迭代和达到 max_iter
时调用 checkpointer.save
方法保存检查点。
- 删除旧检查点:如果设置了
max_to_keep
,则维护一个最近检查点的列表,超过 max_to_keep
的检查点会被删除。
- 手动保存:提供一个
save
方法,允许在调度之外手动保存检查点。
这种设计确保了训练过程中的检查点管理,使得模型在训练过程中能够定期保存和恢复。
PeriodicCheckpointer
这段代码定义了一个 PeriodicCheckpointer
类,继承自 _PeriodicCheckpointer
和 HookBase
,用于在训练过程中定期保存检查点。
类定义及文档
1
2
3
4
5
6
7
8
9
10
| class PeriodicCheckpointer(_PeriodicCheckpointer, HookBase):
"""
Same as :class:`detectron2.checkpoint.PeriodicCheckpointer`, but as a hook.
Note that when used as a hook,
it is unable to save additional data other than what's defined
by the given `checkpointer`.
It is executed every ``period`` iterations and after the last iteration.
"""
|
PeriodicCheckpointer
类与 detectron2.checkpoint.PeriodicCheckpointer
类相同,但作为钩子使用。
- 作为钩子时,无法保存除给定
checkpointer
定义的数据以外的其他数据。
- 每隔
period
次迭代和最后一次迭代后执行。
before_train
方法
1
2
| def before_train(self):
self.max_iter = self.trainer.max_iter
|
- 在训练开始前调用。
- 获取并设置训练的最大迭代次数
max_iter
。
after_step
方法
1
2
3
| def after_step(self):
# No way to use **kwargs
self.step(self.trainer.iter)
|
- 在每次迭代结束后调用。
- 调用
step
方法并传递当前迭代次数 self.trainer.iter
,执行保存检查点的操作。
继承关系
_PeriodicCheckpointer
:这是一个假设的基类,通常包含实际的检查点保存逻辑。
HookBase
:提供钩子的基本接口。
主要功能
PeriodicCheckpointer
类的主要功能是在训练过程中定期保存模型的检查点。它通过定期调用 step
方法实现这一点。
示例解释
假设 _PeriodicCheckpointer
的 step
方法如下:
1
2
3
4
5
6
7
8
9
| class _PeriodicCheckpointer:
def __init__(self, checkpointer, period, max_iter):
self.checkpointer = checkpointer
self.period = period
self.max_iter = max_iter
def step(self, iteration):
if (iteration + 1) % self.period == 0 or iteration == self.max_iter - 1:
self.checkpointer.save("model_{:07d}".format(iteration))
|
PeriodicCheckpointer
类可以这样使用:
1
2
3
4
5
6
7
8
| # 创建一个假设的 checkpointer 对象
checkpointer = Checkpointer(model, optimizer, "path/to/checkpoints")
# 创建 PeriodicCheckpointer 钩子
periodic_checkpointer = PeriodicCheckpointer(checkpointer, period=100, max_iter=1000)
# 注册钩子到 Trainer 中
trainer.register_hooks([periodic_checkpointer])
|
在每隔 period
次迭代(例如每 100 次迭代)和最后一次迭代时,将调用 checkpointer.save()
方法保存模型的检查点。
总结
PeriodicCheckpointer
类作为一个钩子,用于在训练过程中定期保存模型的检查点。它继承了 _PeriodicCheckpointer
和 HookBase
,在 before_train
方法中初始化最大迭代次数,并在 after_step
方法中定期调用 step
方法保存检查点。这样设计的好处是能够轻松集成到训练过程中,并确保定期保存模型状态以便于后续的恢复和分析。
LRScheduler
这段代码定义了一个 LRScheduler
类,继承自 HookBase
。LRScheduler
类用于在每次迭代后执行一个 PyTorch 内置的学习率调度器,并总结学习率。以下是对代码的详细解释:
类定义及文档
1
2
3
4
5
| class LRScheduler(HookBase):
"""
A hook which executes a torch builtin LR scheduler and summarizes the LR.
It is executed after every iteration.
"""
|
LRScheduler
类是一个钩子,它在每次迭代后执行 PyTorch 内置的学习率调度器,并总结当前的学习率。
初始化方法
1
2
3
4
5
6
7
8
9
10
11
12
| def __init__(self, optimizer=None, scheduler=None):
"""
Args:
optimizer (torch.optim.Optimizer):
scheduler (torch.optim.LRScheduler or fvcore.common.param_scheduler.ParamScheduler):
if a :class:`ParamScheduler` object, it defines the multiplier over the base LR
in the optimizer.
If any argument is not given, will try to obtain it from the trainer.
"""
self._optimizer = optimizer
self._scheduler = scheduler
|
- 初始化方法接受一个优化器和一个学习率调度器。如果没有提供这些参数,将尝试从训练器中获取。
optimizer
:PyTorch 优化器。
scheduler
:PyTorch 学习率调度器或 ParamScheduler
对象。如果是 ParamScheduler
,则定义了优化器的基本学习率的乘数。
before_train
方法
1
2
3
4
5
6
7
8
9
10
| def before_train(self):
self._optimizer = self._optimizer or self.trainer.optimizer
if isinstance(self.scheduler, ParamScheduler):
self._scheduler = LRMultiplier(
self._optimizer,
self.scheduler,
self.trainer.max_iter,
last_iter=self.trainer.iter - 1,
)
self._best_param_group_id = LRScheduler.get_best_param_group_id(self._optimizer)
|
- 在训练开始前调用。
- 如果没有提供优化器,则从训练器中获取。
- 如果
scheduler
是 ParamScheduler
,则创建一个 LRMultiplier
对象。
- 通过
get_best_param_group_id
方法获取最好的参数组 ID。
get_best_param_group_id
方法
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
| @staticmethod
def get_best_param_group_id(optimizer):
# NOTE: some heuristics on what LR to summarize
# summarize the param group with most parameters
largest_group = max(len(g["params"]) for g in optimizer.param_groups)
if largest_group == 1:
# If all groups have one parameter,
# then find the most common initial LR, and use it for summary
lr_count = Counter([g["lr"] for g in optimizer.param_groups])
lr = lr_count.most_common()[0][0]
for i, g in enumerate(optimizer.param_groups):
if g["lr"] == lr:
return i
else:
for i, g in enumerate(optimizer.param_groups):
if len(g["params"]) == largest_group:
return i
|
- 静态方法,用于获取优化器中最好的参数组 ID。
- 通过启发式方法选择具有最多参数的参数组,如果所有参数组都有一个参数,则选择初始学习率最常见的参数组。
after_step
方法
1
2
3
4
| def after_step(self):
lr = self._optimizer.param_groups[self._best_param_group_id]["lr"]
self.trainer.storage.put_scalar("lr", lr, smoothing_hint=False)
self.scheduler.step()
|
- 在每次迭代后调用。
- 获取当前最好的参数组的学习率,并将其存储在训练器的存储中。
- 调用调度器的
step
方法更新学习率。
scheduler
属性
1
2
3
| @property
def scheduler(self):
return self._scheduler or self.trainer.scheduler
|
- 属性方法,返回当前的学习率调度器。如果没有提供调度器,则从训练器中获取。
state_dict
方法
1
2
3
4
| def state_dict(self):
if isinstance(self.scheduler, _LRScheduler):
return self.scheduler.state_dict()
return {}
|
- 返回调度器的状态字典,用于保存当前状态。如果调度器是
_LRScheduler
类型,则返回其状态字典。
load_state_dict
方法
1
2
3
4
5
| def load_state_dict(self, state_dict):
if isinstance(self.scheduler, _LRScheduler):
logger = logging.getLogger(__name__)
logger.info("Loading scheduler from state_dict ...")
self.scheduler.load_state_dict(state_dict)
|
- 从状态字典中加载调度器的状态。
- 如果调度器是
_LRScheduler
类型,则记录信息并加载状态字典。
总结
LRScheduler
类是一个用于在每次迭代后执行学习率调度器并总结学习率的钩子。它的主要功能包括:
- 初始化优化器和调度器。
- 在训练开始前设置优化器和调度器。
- 在每次迭代后更新和存储当前学习率。
- 支持保存和加载调度器的状态。
这种设计使得训练过程中学习率的调整和记录更加灵活和自动化,便于监控和调整模型的训练过程。
Evaluator
DDP
launch
这段代码定义了一个名为 launch
的函数,用于启动多进程或分布式训练。这个函数适用于需要在多个机器上进行训练的场景,特别是当每台机器上有多个 GPU 时。以下是对这段代码的详细解释:
函数签名
1
2
3
4
5
6
7
8
9
| def launch(
main_func,
num_gpus_per_machine,
num_machines=1,
machine_rank=0,
dist_url=None,
args=(),
timeout=DEFAULT_TIMEOUT,
):
|
main_func
: 一个函数对象,将在训练过程中被调用。
num_gpus_per_machine
(int): 每台机器上的进程数,通常为 GPU 的数量。
num_machines
(int, 默认值=1): 参与训练的总机器数。
machine_rank
(int, 默认值=0): 当前机器的排名。
dist_url
(str, 默认值=None): 用于分布式训练连接的 URL,包括协议部分,例如 "tcp://127.0.0.1:8686"
。可以设置为 "auto"
自动选择本地主机上的空闲端口。
args
(tuple, 默认值=()): 传递给 main_func
的参数。
timeout
(timedelta, 默认值=DEFAULT_TIMEOUT): 分布式工作器的超时时间。
函数文档字符串
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
| """
Launch multi-process or distributed training.
This function must be called on all machines involved in the training.
It will spawn child processes (defined by ``num_gpus_per_machine``) on each machine.
Args:
main_func: a function that will be called by `main_func(*args)`
num_gpus_per_machine (int): number of processes per machine. When
using GPUs, this should be the number of GPUs.
num_machines (int): the total number of machines
machine_rank (int): the rank of this machine
dist_url (str): url to connect to for distributed jobs, including protocol
e.g. "tcp://127.0.0.1:8686".
Can be set to "auto" to automatically select a free port on localhost
timeout (timedelta): timeout of the distributed workers
args (tuple): arguments passed to main_func
"""
|
这个文档字符串解释了每个参数的含义以及函数的用途,即启动多进程或分布式训练。
函数主体
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
| world_size = num_machines * num_gpus_per_machine
if world_size > 1:
if dist_url == "auto":
assert num_machines == 1, "dist_url=auto not supported in multi-machine jobs."
port = _find_free_port()
dist_url = f"tcp://127.0.0.1:{port}"
if num_machines > 1 and dist_url.startswith("file://"):
logger = logging.getLogger(__name__)
logger.warning(
"file:// is not a reliable init_method in multi-machine jobs. Prefer tcp://"
)
mp.start_processes(
_distributed_worker,
nprocs=num_gpus_per_machine,
args=(
main_func,
world_size,
num_gpus_per_machine,
machine_rank,
dist_url,
args,
timeout,
),
daemon=False,
)
else:
main_func(*args)
|
world_size = num_machines * num_gpus_per_machine
:计算总进程数。
if world_size > 1
:如果总进程数大于 1,意味着需要多进程或分布式训练。
if dist_url == "auto"
:如果 dist_url
是 auto
,自动选择本地可用端口,但仅支持单机模式。
if num_machines > 1 and dist_url.startswith("file://")
:如果 num_machines
大于 1 且 dist_url
以 file://
开头,记录一个警告,因为 file://
在多机模式下不可靠。
mp.start_processes
:启动多进程,调用 _distributed_worker
函数,每台机器启动 num_gpus_per_machine
个进程,并传递必要的参数。
else
:如果总进程数为 1,直接调用 main_func
。
关键点
- 多机多进程支持:支持在多台机器上启动多进程训练,每台机器上的进程数由 GPU 数量决定。
- 自动端口选择:支持自动选择本地空闲端口,但仅限单机模式。
- 分布式训练:支持通过
tcp://
或 file://
协议进行分布式训练,但 file://
在多机模式下不推荐。
这段代码主要用于分布式机器学习训练环境的初始化和进程管理。
_distributed_worker
这段代码定义了一个名为 _distributed_worker
的函数,用于在分布式训练中初始化和管理每个子进程。以下是对这段代码的详细解释:
函数签名
1
2
3
4
5
6
7
8
9
10
| def _distributed_worker(
local_rank,
main_func,
world_size,
num_gpus_per_machine,
machine_rank,
dist_url,
args,
timeout=DEFAULT_TIMEOUT,
):
|
local_rank
: 当前进程在本地机器上的排名。
main_func
: 一个函数对象,将在训练过程中被调用。
world_size
: 总进程数。
num_gpus_per_machine
: 每台机器上的 GPU 数量,即进程数。
machine_rank
: 当前机器的排名。
dist_url
: 用于分布式训练连接的 URL。
args
: 传递给 main_func
的参数。
timeout
: 分布式工作器的超时时间。
函数主体
1
2
3
4
| has_gpu = torch.cuda.is_available()
if has_gpu:
assert num_gpus_per_machine <= torch.cuda.device_count()
global_rank = machine_rank * num_gpus_per_machine + local_rank
|
has_gpu = torch.cuda.is_available()
: 检查当前环境是否有可用的 GPU。
- 如果有 GPU,则
assert num_gpus_per_machine <= torch.cuda.device_count()
检查每台机器上的进程数是否小于等于可用的 GPU 数量。
- 计算
global_rank
: 全局排名,等于机器排名乘以每台机器的进程数再加上本地排名。
1
2
3
4
5
6
7
8
9
10
11
12
| try:
dist.init_process_group(
backend="NCCL" if has_gpu else "GLOO",
init_method=dist_url,
world_size=world_size,
rank=global_rank,
timeout=timeout,
)
except Exception as e:
logger = logging.getLogger(__name__)
logger.error("Process group URL: {}".format(dist_url))
raise e
|
dist.init_process_group
初始化分布式进程组,使用的后端根据是否有 GPU 决定(NCCL
或 GLOO
),并设置世界大小、排名和超时时间。
- 如果初始化失败,捕获异常并记录错误信息,然后重新抛出异常。
1
2
3
4
| # Setup the local process group.
comm.create_local_process_group(num_gpus_per_machine)
if has_gpu:
torch.cuda.set_device(local_rank)
|
comm.create_local_process_group(num_gpus_per_machine)
设置本地进程组。
- 如果有 GPU,使用
torch.cuda.set_device(local_rank)
设置当前进程使用的 GPU。
1
2
3
4
5
| # synchronize is needed here to prevent a possible timeout after calling init_process_group
# See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172
comm.synchronize()
main_func(*args)
|
comm.synchronize()
同步所有进程,防止在调用 init_process_group
后可能出现的超时问题。
- 最后,调用
main_func(*args)
执行主要的训练函数。
关键点
- 环境检查:首先检查是否有可用的 GPU,并验证每台机器的进程数不超过可用 GPU 数量。
- 全局排名:根据机器排名和本地排名计算每个进程的全局排名。
- 分布式进程组初始化:使用
torch.distributed
初始化进程组,支持 GPU 和非 GPU 环境。
- 本地进程组和设备设置:设置本地进程组,并在有 GPU 的情况下为每个进程分配 GPU。
- 同步:在初始化进程组后进行同步,防止超时。
- 调用训练函数:最后,调用主要的训练函数
main_func
,开始训练任务。
这段代码主要用于在分布式训练环境中启动和管理每个子进程,确保每个进程能够正确初始化并同步,最终执行训练任务。
create_local_process_group
这段代码定义了一个函数 create_local_process_group
,用于在每台机器上创建一个包含本地进程的进程组。在分布式训练中,这样的进程组对于处理本地任务和同步非常有用。以下是对这段代码的详细解释:
函数装饰器
- 这个装饰器来自
functools
模块,使用 lru_cache
(最近最少使用缓存)来缓存函数的返回值。这意味着,如果使用相同的参数多次调用 create_local_process_group
,它将直接返回缓存的结果,而不需要重新执行函数体。
函数定义
1
| def create_local_process_group(num_workers_per_machine: int) -> None:
|
- 定义了一个名为
create_local_process_group
的函数,参数 num_workers_per_machine
表示每台机器上的工作进程数,通常是 GPU 的数量。
函数文档字符串
1
2
3
4
5
6
7
8
9
10
11
12
13
| """
Create a process group that contains ranks within the same machine.
Detectron2's launch() in engine/launch.py will call this function. If you start
workers without launch(), you'll have to also call this. Otherwise utilities
like `get_local_rank()` will not work.
This function contains a barrier. All processes must call it together.
Args:
num_workers_per_machine: the number of worker processes per machine. Typically
the number of GPUs.
"""
|
- 解释了函数的用途:创建一个包含同一台机器内排名的进程组。
- 说明了函数的调用场景:如果未使用
launch()
启动工作进程,则需要手动调用此函数,否则某些工具(如 get_local_rank()
)将无法正常工作。
- 这个函数包含一个 barrier,所有进程必须同时调用它。
函数主体
1
2
3
4
5
6
7
8
9
10
| global _LOCAL_PROCESS_GROUP
assert _LOCAL_PROCESS_GROUP is None
assert get_world_size() % num_workers_per_machine == 0
num_machines = get_world_size() // num_workers_per_machine
machine_rank = get_rank() // num_workers_per_machine
for i in range(num_machines):
ranks_on_i = list(range(i * num_workers_per_machine, (i + 1) * num_workers_per_machine))
pg = dist.new_group(ranks_on_i)
if i == machine_rank:
_LOCAL_PROCESS_GROUP = pg
|
global _LOCAL_PROCESS_GROUP
: 声明 _LOCAL_PROCESS_GROUP
变量为全局变量,用于存储本地进程组。
assert _LOCAL_PROCESS_GROUP is None
: 确保 _LOCAL_PROCESS_GROUP
尚未被创建。
assert get_world_size() % num_workers_per_machine == 0
: 确保总进程数可以被每台机器上的工作进程数整除。
num_machines = get_world_size() // num_workers_per_machine
: 计算总的机器数。
machine_rank = get_rank() // num_workers_per_machine
: 计算当前机器的排名。
- 使用一个循环遍历每台机器:
ranks_on_i = list(range(i * num_workers_per_machine, (i + 1) * num_workers_per_machine))
: 创建一个包含当前机器上所有进程排名的列表。
pg = dist.new_group(ranks_on_i)
: 使用 dist.new_group
创建一个新的进程组。
if i == machine_rank: _LOCAL_PROCESS_GROUP = pg
: 如果当前机器是本地机器,将创建的进程组赋值给全局变量 _LOCAL_PROCESS_GROUP
。
关键点
- 缓存:通过
lru_cache
装饰器缓存函数结果,避免重复创建本地进程组。
- 全局变量:使用全局变量
_LOCAL_PROCESS_GROUP
存储本地进程组。
- 断言:确保函数在正确的条件下执行,例如
_LOCAL_PROCESS_GROUP
为空,总进程数可整除每台机器的工作进程数。
- 进程组创建:为每台机器上的进程创建一个新的进程组,并将本地机器的进程组存储起来。
这段代码在分布式训练中用于创建和管理本地进程组,确保每台机器上的进程能够正确同步和协同工作。
dist.get_world_size()
在 PyTorch 中,dist.get_world_size()
是一个用于分布式训练的函数,定义在 torch.distributed
模块中。这个函数返回当前默认进程组中的总进程数(即世界大小)。以下是对 dist.get_world_size()
的详细解释:
函数签名
1
| torch.distributed.get_world_size(group=GroupMember.WORLD)
|
参数
group
(可选):指定要查询的进程组。如果未指定,默认为 GroupMember.WORLD
,即查询默认进程组。
返回值
用途
dist.get_world_size()
通常用于分布式训练中,以获取训练任务中参与的所有进程的数量。这在一些场景中非常有用,例如需要根据进程总数进行工作负载划分、数据分片等。
示例代码
以下是一个简单的示例,演示如何使用 dist.get_world_size()
:
1
2
3
4
5
6
7
8
9
10
11
12
| import torch.distributed as dist
def main():
# 初始化分布式环境
dist.init_process_group(backend="nccl", init_method="tcp://127.0.0.1:23456", world_size=4, rank=0)
# 获取世界大小
world_size = dist.get_world_size()
print(f"World size: {world_size}")
if __name__ == "__main__":
main()
|
在这个示例中:
dist.init_process_group
用于初始化分布式进程组,指定了 backend
、init_method
、world_size
和 rank
。
dist.get_world_size()
获取初始化后的进程组的总进程数,并打印出来。
应用场景
- 数据并行:在数据并行训练中,获取世界大小用于计算每个进程应处理的数据量。例如,将数据集分成多个部分,每个进程处理其中一部分。
- 梯度聚合:在分布式训练中,了解世界大小对于梯度聚合过程是必要的,可以确保所有进程正确地参与梯度计算和更新。
- 资源分配:根据世界大小,可以动态调整资源分配策略,例如根据进程数分配 GPU 或其他硬件资源。
总之,dist.get_world_size()
是 PyTorch 分布式训练中一个重要的辅助函数,用于获取当前进程组的规模,帮助开发者编写更加灵活和高效的分布式训练代码。