Hello Mat

 找回密码
 立即注册
查看: 4070|回复: 0

配置文件字典操作

[复制链接]

1323

主题

1551

帖子

0

金钱

管理员

Rank: 9Rank: 9Rank: 9

积分
22647
发表于 2021-5-10 22:39:11 | 显示全部楼层 |阅读模式
  1. from addict import Dict

  2. config = Dict()
  3. config.exp_name = 'CRNN'
  4. config.train_options = {
  5.     # for train
  6.     'resume_from': '',  # 继续训练地址
  7.     'third_party_name': '',  # 加载paddle模型可选
  8.     'checkpoint_save_dir': f"./output/{config.exp_name}/checkpoint",  # 模型保存地址,log文件也保存在这里
  9.     'device': 'cuda:0',  # 不建议修改
  10.     'epochs': 200,
  11.     'fine_tune_stage': ['backbone', 'neck', 'head'],
  12.     'print_interval': 10,  # step为单位
  13.     'val_interval': 300,  # step为单位
  14.     'ckpt_save_type': 'HighestAcc',  # HighestAcc:只保存最高准确率模型 ;FixedEpochStep:每隔ckpt_save_epoch个epoch保存一个
  15.     'ckpt_save_epoch': 4,  # epoch为单位, 只有ckpt_save_type选择FixedEpochStep时,该参数才有效
  16. }

  17. config.SEED = 927
  18. config.optimizer = {
  19.     'type': 'Adam',
  20.     'lr': 0.001,
  21.     'weight_decay': 1e-4,
  22. }

  23. config.lr_scheduler = {
  24.     'type': 'StepLR',
  25.     'step_size': 60,
  26.     'gamma': 0.5
  27. }
  28. config.model = {
  29.     # backbone 可以设置'pretrained': False/True
  30.     'type': "RecModel",
  31.     # 'backbone': {"type": "ResNet", 'layers': 34},
  32.     # 'neck': {"type": 'PPaddleRNN',"hidden_size": 256},
  33.     # 'head': {"type": "CTC", 'n_class': 93},
  34.     # 'in_channels': 3,

  35.     'backbone': {"type": "MobileNetV3", 'model_name': 'small'},
  36.     'neck': {"type": 'PPaddleRNN', "hidden_size": 48},
  37.     'head': {"type": "CTC", 'n_class': 93},
  38.     'in_channels': 3,
  39. }

  40. config.loss = {
  41.     'type': 'CTCLoss',
  42.     'blank_idx': 0,
  43. }

  44. # for dataset
  45. config.dataset = {
  46.     'alphabet': r'torchocr/datasets/alphabets/digit.txt',
  47.     'train': {
  48.         'dataset': {
  49.             'type': 'RecTextLineDataset',
  50.             'file': r'path/train.txt',
  51.             'input_h': 32,
  52.             'mean': 0.5,
  53.             'std': 0.5,
  54.             'augmentation': False,
  55.         },
  56.         'loader': {
  57.             'type': 'DataLoader',  # 使用torch dataloader只需要改为 DataLoader
  58.             'batch_size': 16,
  59.             'shuffle': True,
  60.             'num_workers': 1,
  61.             'collate_fn': {
  62.                 'type': 'RecCollateFn',
  63.                 'img_w': 120
  64.             }
  65.         }
  66.     },
  67.     'eval': {
  68.         'dataset': {
  69.             'type': 'RecTextLineDataset',
  70.             'file': r'path/eval.txt',
  71.             'input_h': 32,
  72.             'mean': 0.5,
  73.             'std': 0.5,
  74.             'augmentation': False,
  75.         },
  76.         'loader': {
  77.             'type': 'RecDataLoader',
  78.             'batch_size': 4,
  79.             'shuffle': False,
  80.             'num_workers': 1,
  81.             'collate_fn': {
  82.                 'type': 'RecCollateFn',
  83.                 'img_w': 120
  84.             }
  85.         }
  86.     }
  87. }

  88. # 转换为 Dict
  89. for k, v in config.items():
  90.     if isinstance(v, dict):
  91.         config[k] = Dict(v)
复制代码
pip install addict
  1. config.train_options
  2. Out[3]:
  3. {'resume_from': '',
  4. 'third_party_name': '',
  5. 'checkpoint_save_dir': './output/CRNN/checkpoint',
  6. 'device': 'cuda:0',
  7. 'epochs': 200,
  8. 'fine_tune_stage': ['backbone', 'neck', 'head'],
  9. 'print_interval': 10,
  10. 'val_interval': 300,
  11. 'ckpt_save_type': 'HighestAcc',
  12. 'ckpt_save_epoch': 4}

  13. config.train_options.val_interval
  14. Out[4]: 300
复制代码







算法QQ  3283892722
群智能算法链接http://halcom.cn/forum.php?mod=forumdisplay&fid=73
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

Python|Opencv|MATLAB|Halcom.cn ( 蜀ICP备16027072号 )

GMT+8, 2024-11-23 00:28 , Processed in 0.187294 second(s), 21 queries .

Powered by Discuz! X3.4

Copyright © 2001-2021, Tencent Cloud.

快速回复 返回顶部 返回列表