配置文件字典操作
from addict import Dictconfig = Dict()
config.exp_name = 'CRNN'
config.train_options = {
# for train
'resume_from': '',# 继续训练地址
'third_party_name': '',# 加载paddle模型可选
'checkpoint_save_dir': f"./output/{config.exp_name}/checkpoint",# 模型保存地址,log文件也保存在这里
'device': 'cuda:0',# 不建议修改
'epochs': 200,
'fine_tune_stage': ['backbone', 'neck', 'head'],
'print_interval': 10,# step为单位
'val_interval': 300,# step为单位
'ckpt_save_type': 'HighestAcc',# HighestAcc:只保存最高准确率模型 ;FixedEpochStep:每隔ckpt_save_epoch个epoch保存一个
'ckpt_save_epoch': 4,# epoch为单位, 只有ckpt_save_type选择FixedEpochStep时,该参数才有效
}
config.SEED = 927
config.optimizer = {
'type': 'Adam',
'lr': 0.001,
'weight_decay': 1e-4,
}
config.lr_scheduler = {
'type': 'StepLR',
'step_size': 60,
'gamma': 0.5
}
config.model = {
# backbone 可以设置'pretrained': False/True
'type': "RecModel",
# 'backbone': {"type": "ResNet", 'layers': 34},
# 'neck': {"type": 'PPaddleRNN',"hidden_size": 256},
# 'head': {"type": "CTC", 'n_class': 93},
# 'in_channels': 3,
'backbone': {"type": "MobileNetV3", 'model_name': 'small'},
'neck': {"type": 'PPaddleRNN', "hidden_size": 48},
'head': {"type": "CTC", 'n_class': 93},
'in_channels': 3,
}
config.loss = {
'type': 'CTCLoss',
'blank_idx': 0,
}
# for dataset
config.dataset = {
'alphabet': r'torchocr/datasets/alphabets/digit.txt',
'train': {
'dataset': {
'type': 'RecTextLineDataset',
'file': r'path/train.txt',
'input_h': 32,
'mean': 0.5,
'std': 0.5,
'augmentation': False,
},
'loader': {
'type': 'DataLoader',# 使用torch dataloader只需要改为 DataLoader
'batch_size': 16,
'shuffle': True,
'num_workers': 1,
'collate_fn': {
'type': 'RecCollateFn',
'img_w': 120
}
}
},
'eval': {
'dataset': {
'type': 'RecTextLineDataset',
'file': r'path/eval.txt',
'input_h': 32,
'mean': 0.5,
'std': 0.5,
'augmentation': False,
},
'loader': {
'type': 'RecDataLoader',
'batch_size': 4,
'shuffle': False,
'num_workers': 1,
'collate_fn': {
'type': 'RecCollateFn',
'img_w': 120
}
}
}
}
# 转换为 Dict
for k, v in config.items():
if isinstance(v, dict):
config = Dict(v)pip install addictconfig.train_options
Out:
{'resume_from': '',
'third_party_name': '',
'checkpoint_save_dir': './output/CRNN/checkpoint',
'device': 'cuda:0',
'epochs': 200,
'fine_tune_stage': ['backbone', 'neck', 'head'],
'print_interval': 10,
'val_interval': 300,
'ckpt_save_type': 'HighestAcc',
'ckpt_save_epoch': 4}
config.train_options.val_interval
Out: 300
页:
[1]