Config, Argument and JSON Schema

2019.04.14
SF-Zhou

最近在重构现有的 PyTorch 训练框架,目标是组件化、低耦合、高复用,提高大家的研究效率。开发过程中思考了一些关于 Config 和 Command Line Argument 的问题,并且顺手开发了几个辅助小工具。本文可以看成是这几个工具的说明文档了。

1. Config File

训练框架里,Config File 的使用非常广泛。配置文件一般使用 YAML 和 JSON 格式,其中存储了关于实验的一些细节,比如用哪一种模型、多少的学习率以及使用哪种数据集。一般训练开始时直接将配置文件加载进来,读取为 dict,然后根据需要的 key 获得 value 即可,比如 config["model_type"]

然而这种做法存在一些潜在风险:

  1. 用户并不知道如何配置。如果没有详细的配置文档,用户是不知道可以设置哪些 key 以及对应的 value 的。即使在配置文件中设置了 model_type,你也无法确定 model_type 是否被读取和使用了;
  2. 获取默认值时的潜在风险,例如 config.get("model_type", "ResNet")。这种写法令人“害怕”,即使你的 key 写错了,代码依然正常执行;
  3. 编写代码过程中心智负担过重。即使开发者知道 config 有哪些 key,也会担心自己的 key 是不是写错了。

针对上述问题,笔者开发了一个小工具 json-schema-to-class。可以使用 pip3 安装(Python 3.6+):

pip3 install json-schema-to-class

该工具的目的就是通过 JSON Schema 来指导和约束开发者和用户:

  1. 用户可以参考 Schema 文件来编写配置文件;
  2. Schema 文件附带参数默认值及参数描述;
  3. 通过运行时读取 schema 文件并编译为 Python Class 代码,使得开发者在开发过程中获得足够的代码提示,规避了直接写 key 的烦恼。

直接参考一个现有的 demo 吧:torch-basic-models。该项目提供了两个基础的 PyTorch 模型代码:ResNet 和 MobileNetV2。两个模型具有不同配置项,其 Schema 放置于 torch_basic_models/schema。以 ResNet 为例,其Schema 定义如下:

{
  "$schema": "http://json-schema.org/draft-04/schema#",
  "title": "res_net_config",
  "description": "ResNet Model Config",
  "type": "object",
  "properties": {
    "layers_list": {
      "type": "array",
      "items": {
        "type": "integer"
      },
      "default": [3, 4, 6, 3]
    },
    "feature_dim": {
      "type": "integer",
      "default": 1000
    }
  },
  "additionalProperties": false
}

用户可以轻松的知道配置项目包括 layers_listfeature_dim,即使没有写明 description 也能猜到对应的含义。而开发者可以通过 torch_basic_models/configs/__init__.py 调用 json-schema-to-class 动态生成 torch_basic_models/configs/build/resnet_config.py

from typing import List


class ResNetConfig:
    def __init__(self, values: dict = None):
        values = values if values is not None else {}
        self.layers_list: List[int] = values.get("layers_list", [3, 4, 6, 3])
        self.feature_dim: int = values.get("feature_dim", 1000)

进而在编写代码过程直接访问 Python 对象来获取配置项,并且附带代码提示不至于犯小错误。有兴趣的同学可以详细研究下这个小项目。

Command Line Argument

Python 自带的 argparse 可以读取命令行参数,然而配置参数项和读取参数值时依然不够方便。以 PyTorch 官方的 ImageNet 训练项目为例:

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                    choices=model_names,
                    help='model architecture: ' +
                        ' | '.join(model_names) +
                        ' (default: resnet18)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')

# ...

def main():
    args = parser.parse_args()

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

同样令人“害怕😨”。笔者同样利用 JSON Schema 开发了 argparse-schema 来解决这个问题。该工具同样可以使用 pip3 直接安装,通过读取 Schema 文件来构建 argparse.ArgumentParser 对象。Schema 举例 tests/argument_config.json

{
  "$schema": "http://json-schema.org/draft-04/schema#",
  "title": "argument_config",
  "description": "Arg-parse Schema UnitTest",
  "type": "object",
  "properties": {
    "config": {
      "type": "string",
      "description": "path of config file"
    },
    "resume": {
      "type": "boolean",
      "description": "resume from checkpoint or not"
    },
    "scale": {
      "type": "number",
      "default": 1.0,
      "description": "scale of image"
    }
  },
  "required": [
    "config"
  ]
}

可以通过如下的 Python 代码调用:

# demo.py
import argparse_schema

print(argparse_schema.parse(schema='./tests/argument_config.json'))

执行:

python3 demo.py --config /path/to/config.py
#> {'config': '/path/to/config.py', 'resume': False, 'scale': 1.0}

也可以通过提供的 CLI 工具查看 Argument Schema 对应的命令行帮助:

argparse-schema --schema_path tests/argument_config.json
#> Show help for [tests/argument_config.json]:
#> usage: argparse-schema [-h] --config CONFIG [--resume] [--scale SCALE]
#>
#> Arg-parse Schema UnitTest
#>
#> optional arguments:
#>   -h, --help       show this help message and exit
#>   --config CONFIG  path of config file
#>   --resume         resume from checkpoint or not
#>   --scale SCALE    scale of image

json-schema-to-class 结合后,可以得到更为 Fancy 的结果,可以打开脑洞想一下。

References

  1. JSON Schema, https://json-schema.org