Halcom 发表于 2021-5-10 22:39:11

配置文件字典操作

from addict import Dict

config = Dict()
config.exp_name = 'CRNN'
config.train_options = {
    # for train
    'resume_from': '',# 继续训练地址
    'third_party_name': '',# 加载paddle模型可选
    'checkpoint_save_dir': f"./output/{config.exp_name}/checkpoint",# 模型保存地址,log文件也保存在这里
    'device': 'cuda:0',# 不建议修改
    'epochs': 200,
    'fine_tune_stage': ['backbone', 'neck', 'head'],
    'print_interval': 10,# step为单位
    'val_interval': 300,# step为单位
    'ckpt_save_type': 'HighestAcc',# HighestAcc:只保存最高准确率模型 ;FixedEpochStep:每隔ckpt_save_epoch个epoch保存一个
    'ckpt_save_epoch': 4,# epoch为单位, 只有ckpt_save_type选择FixedEpochStep时,该参数才有效
}

config.SEED = 927
config.optimizer = {
    'type': 'Adam',
    'lr': 0.001,
    'weight_decay': 1e-4,
}

config.lr_scheduler = {
    'type': 'StepLR',
    'step_size': 60,
    'gamma': 0.5
}
config.model = {
    # backbone 可以设置'pretrained': False/True
    'type': "RecModel",
    # 'backbone': {"type": "ResNet", 'layers': 34},
    # 'neck': {"type": 'PPaddleRNN',"hidden_size": 256},
    # 'head': {"type": "CTC", 'n_class': 93},
    # 'in_channels': 3,

    'backbone': {"type": "MobileNetV3", 'model_name': 'small'},
    'neck': {"type": 'PPaddleRNN', "hidden_size": 48},
    'head': {"type": "CTC", 'n_class': 93},
    'in_channels': 3,
}

config.loss = {
    'type': 'CTCLoss',
    'blank_idx': 0,
}

# for dataset
config.dataset = {
    'alphabet': r'torchocr/datasets/alphabets/digit.txt',
    'train': {
      'dataset': {
            'type': 'RecTextLineDataset',
            'file': r'path/train.txt',
            'input_h': 32,
            'mean': 0.5,
            'std': 0.5,
            'augmentation': False,
      },
      'loader': {
            'type': 'DataLoader',# 使用torch dataloader只需要改为 DataLoader
            'batch_size': 16,
            'shuffle': True,
            'num_workers': 1,
            'collate_fn': {
                'type': 'RecCollateFn',
                'img_w': 120
            }
      }
    },
    'eval': {
      'dataset': {
            'type': 'RecTextLineDataset',
            'file': r'path/eval.txt',
            'input_h': 32,
            'mean': 0.5,
            'std': 0.5,
            'augmentation': False,
      },
      'loader': {
            'type': 'RecDataLoader',
            'batch_size': 4,
            'shuffle': False,
            'num_workers': 1,
            'collate_fn': {
                'type': 'RecCollateFn',
                'img_w': 120
            }
      }
    }
}

# 转换为 Dict
for k, v in config.items():
    if isinstance(v, dict):
      config = Dict(v)pip install addictconfig.train_options
Out:
{'resume_from': '',
'third_party_name': '',
'checkpoint_save_dir': './output/CRNN/checkpoint',
'device': 'cuda:0',
'epochs': 200,
'fine_tune_stage': ['backbone', 'neck', 'head'],
'print_interval': 10,
'val_interval': 300,
'ckpt_save_type': 'HighestAcc',
'ckpt_save_epoch': 4}

config.train_options.val_interval
Out: 300






页: [1]
查看完整版本: 配置文件字典操作