|
from addict import Dict

# Training configuration for a CRNN text-recognition model.
# Each section is written as a plain dict and converted to addict.Dict at the
# bottom of this script so values can be read with attribute access.
config = Dict()
config.exp_name = 'CRNN'
config.train_options = {
    # for train
    'resume_from': '',  # path of a checkpoint to resume training from
    'third_party_name': '',  # optional; used when loading a Paddle model
    # Models are saved here; log files are written here as well.
    'checkpoint_save_dir': f"./output/{config.exp_name}/checkpoint",
    'device': 'cuda:0',  # changing this is not recommended
    'epochs': 200,
    'fine_tune_stage': ['backbone', 'neck', 'head'],
    'print_interval': 10,  # measured in steps
    'val_interval': 300,  # measured in steps
    # 'HighestAcc': keep only the highest-accuracy model;
    # 'FixedEpochStep': save a checkpoint every `ckpt_save_epoch` epochs.
    'ckpt_save_type': 'HighestAcc',
    # Measured in epochs; only used when ckpt_save_type is 'FixedEpochStep'.
    'ckpt_save_epoch': 4,
}
config.SEED = 927
config.optimizer = {
    'type': 'Adam',
    'lr': 0.001,
    'weight_decay': 1e-4,
}
config.lr_scheduler = {
    'type': 'StepLR',
    'step_size': 60,
    'gamma': 0.5,
}
config.model = {
    # The backbone also accepts 'pretrained': False/True.
    'type': "RecModel",
    # Alternative setup with a ResNet-34 backbone:
    # 'backbone': {"type": "ResNet", 'layers': 34},
    # 'neck': {"type": 'PPaddleRNN', "hidden_size": 256},
    # 'head': {"type": "CTC", 'n_class': 93},
    # 'in_channels': 3,
    'backbone': {"type": "MobileNetV3", 'model_name': 'small'},
    'neck': {"type": 'PPaddleRNN', "hidden_size": 48},
    # NOTE(review): n_class presumably has to match the size of the alphabet
    # file in config.dataset (plus the CTC blank); confirm 93 is intended
    # with digit.txt.
    'head': {"type": "CTC", 'n_class': 93},
    'in_channels': 3,
}
config.loss = {
    'type': 'CTCLoss',
    'blank_idx': 0,
}
# for dataset
config.dataset = {
    'alphabet': r'torchocr/datasets/alphabets/digit.txt',
    'train': {
        'dataset': {
            'type': 'RecTextLineDataset',
            'file': r'path/train.txt',
            'input_h': 32,
            'mean': 0.5,
            'std': 0.5,
            'augmentation': False,
        },
        'loader': {
            # To use the plain torch DataLoader, set this to 'DataLoader'.
            # NOTE(review): the eval loader below uses 'RecDataLoader';
            # confirm the mismatch is intentional.
            'type': 'DataLoader',
            'batch_size': 16,
            'shuffle': True,
            'num_workers': 1,
            'collate_fn': {
                'type': 'RecCollateFn',
                'img_w': 120,
            },
        },
    },
    'eval': {
        'dataset': {
            'type': 'RecTextLineDataset',
            'file': r'path/eval.txt',
            'input_h': 32,
            'mean': 0.5,
            'std': 0.5,
            'augmentation': False,
        },
        'loader': {
            'type': 'RecDataLoader',
            'batch_size': 4,
            'shuffle': False,
            'num_workers': 1,
            'collate_fn': {
                'type': 'RecCollateFn',
                'img_w': 120,
            },
        },
    },
}

# Convert each top-level plain dict into an addict.Dict so its values can be
# read as attributes, e.g. config.train_options.val_interval.
# Iterate over a snapshot so entries are not rebound in the dict being iterated.
for k, v in list(config.items()):
    if isinstance(v, dict):
        config[k] = Dict(v)
Install the dependency first with `pip install addict`. After the conversion, nested values can be read as attributes, e.g. in IPython:

In [3]: config.train_options
Out[3]:
{'resume_from': '',
 'third_party_name': '',
 'checkpoint_save_dir': './output/CRNN/checkpoint',
 'device': 'cuda:0',
 'epochs': 200,
 'fine_tune_stage': ['backbone', 'neck', 'head'],
 'print_interval': 10,
 'val_interval': 300,
 'ckpt_save_type': 'HighestAcc',
 'ckpt_save_epoch': 4}

In [4]: config.train_options.val_interval
Out[4]: 300
|
|