
End2End-ASR-Pytorch

Posted on 2023-7-28 15:52:12
End2End-ASR-Pytorch
Bilibili videos:
[1] ASR debugging video: https://www.bilibili.com/video/B ... 24aa2e8ea036f9d24f4
[2] ASR: installing torchaudio
[3] End-to-end ASR code walkthrough (trimmed): https://www.bilibili.com/video/B ... 3af3339cdeadc088281
[4] ASR: training the AI model (video)
[5] Step5_ASR: simplified code walkthrough
  1. import sys
  2. import os
  3. import json
  4. import math
  5. import numpy as np
  6. import random
  7. import unicodedata
  8. import scipy.signal
  9. import subprocess
  10. from tempfile import NamedTemporaryFile
  11. import Levenshtein as Lev

  12. import torch
  13. import torch.nn as nn
  14. import torch.nn.functional as F
  15. from torch.autograd import Variable
  16. from torch.utils.data import DataLoader
  17. from torch.utils.data import Dataset
  18. from torch.utils.data.sampler import Sampler

  19. import librosa
  20. import torchaudio

  21. windows = {'hamming': scipy.signal.hamming, 'hann': scipy.signal.hann, 'blackman': scipy.signal.blackman, 'bartlett': scipy.signal.bartlett}  # note: newer SciPy versions expose these under scipy.signal.windows instead
  22. PAD_TOKEN = 0
  23. SOS_TOKEN = 1
  24. EOS_TOKEN = 2
  25. PAD_CHAR = "¶"
  26. SOS_CHAR = "§"
  27. EOS_CHAR = "¤"
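  # These ids line up with the label table built in __main__: the label string there
  # starts with PAD_CHAR + SOS_CHAR + EOS_CHAR, so label2id[PAD_CHAR] == PAD_TOKEN == 0,
  # label2id[SOS_CHAR] == SOS_TOKEN == 1 and label2id[EOS_CHAR] == EOS_TOKEN == 2.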

  28. def calculate_cer(s1, s2):
  29.     """
  30.     Computes the Character Error Rate, defined as the edit distance.

  31.     Arguments:
  32.         s1 (string): space-separated sentence (hyp)
  33.         s2 (string): space-separated sentence (gold)
  34.     """
  35.     return Lev.distance(s1, s2)
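      # e.g. calculate_cer("abc", "abd") == 1 (one character substitution)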

  36. def calculate_wer(s1, s2):
  37.     """
  38.     Computes the Word Error Rate, defined as the edit distance between the
  39.     two provided sentences after tokenizing to words.
  40.     Arguments:
  41.         s1 (string): space-separated sentence
  42.         s2 (string): space-separated sentence
  43.     """
  44.     # build mapping of words to integers
  45.     b = set(s1.split() + s2.split())
  46.     word2char = dict(zip(b, range(len(b))))

  47.     # map the words to a char array (Levenshtein packages only accepts
  48.     # strings)
  49.     w1 = [chr(word2char[w]) for w in s1.split()]
  50.     w2 = [chr(word2char[w]) for w in s2.split()]

  51.     return Lev.distance(''.join(w1), ''.join(w2))
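      # e.g. calculate_wer("hello there world", "hello world") == 1 (one word deleted)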

  52. def calculate_metrics(pred, gold, input_lengths=None, target_lengths=None, smoothing=0.0, loss_type="ce"):
  53.     """
  54.     Calculate metrics
  55.     args:
  56.         pred: B x T x C
  57.         gold: B x T
  58.         input_lengths: B (for CTC)
  59.         target_lengths: B (for CTC)
  60.     """
  61.     loss = calculate_loss(pred, gold, input_lengths, target_lengths, smoothing, loss_type)
  62.     # if loss_type == "ce":
  63.     pred = pred.view(-1, pred.size(2)) # (B*T) x C
  64.     gold = gold.contiguous().view(-1) # (B*T)
  65.     pred = pred.max(1)[1]
  66.     non_pad_mask = gold.ne(PAD_TOKEN)
  67.     num_correct = pred.eq(gold)
  68.     num_correct = num_correct.masked_select(non_pad_mask).sum().item()
  69.     return loss, num_correct

  70. def calculate_loss(pred, gold, input_lengths=None, target_lengths=None, smoothing=0.0, loss_type="ce"):
  71.     """
  72.     Calculate loss
  73.     args:
  74.         pred: B x T x C
  75.         gold: B x T
  76.         input_lengths: B (for CTC)
  77.         target_lengths: B (for CTC)
  78.         smoothing:
  79.         type: ce|ctc (ctc => pytorch 1.0.0 or later)
  80.         input_lengths: B (only for ctc)
  81.         target_lengths: B (only for ctc)
  82.     """
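      # NOTE: this simplified version always applies cross-entropy; the smoothing,
      # loss_type, input_lengths and target_lengths arguments are accepted but ignored.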
  83.     pred = pred.view(-1, pred.size(2)) # (B*T) x C
  84.     gold = gold.contiguous().view(-1) # (B*T)
  85.     loss = F.cross_entropy(pred, gold, ignore_index = PAD_TOKEN, reduction="mean")
  86.     return loss

  87. def train(model, train_loader, train_sampler, valid_loader_list, start_epoch, num_epochs, label2id, id2label, last_metrics=None):
  88.     """
  89.     Training
  90.     args:
  91.         model: Model object
  92.         train_loader: DataLoader object of the training set
  93.         valid_loader_list: a list of Validation DataLoader objects
  94.         opt: Optimizer object
  95.         start_epoch: start epoch (> 0 if you resume the process)
  96.         num_epochs: last epoch
  97.         last_metrics: (if resume)
  98.     """
  99.     best_valid_loss = 1000000000 if last_metrics is None else last_metrics['valid_loss']

  100.     optimizer = torch.optim.Adam([dict(params = model.parameters(), lr = 0.001)])
  101.     for epoch in range(start_epoch, num_epochs):
  102.         sys.stdout.flush()
  103.         total_loss, total_cer, total_wer, total_char, total_word = 0, 0, 0, 0, 0

  104.         model.train()
  105.         model = model.cuda(0)
  106.         for i, (data) in enumerate(train_loader, start = 0):
  107.             src, tgt, src_percentages, src_lengths, tgt_lengths = data
  108.             src = src.cuda()
  109.             tgt = tgt.cuda()

  110.             optimizer.zero_grad()
  111.             # opt.zero_grad()

  112.             pred, gold, hyp_seq, gold_seq = model(src, src_lengths, tgt, verbose=False)

  113.             try: # handle case for CTC
  114.                 strs_gold, strs_hyps = [], []
  115.                 for ut_gold in gold_seq:
  116.                     str_gold = ""
  117.                     for x in ut_gold:
  118.                         if int(x) == PAD_TOKEN:
  119.                             break
  120.                         str_gold = str_gold + id2label[int(x)]
  121.                     strs_gold.append(str_gold)
  122.                 for ut_hyp in hyp_seq:
  123.                     str_hyp = ""
  124.                     for x in ut_hyp:
  125.                         if int(x) == PAD_TOKEN:
  126.                             break
  127.                         str_hyp = str_hyp + id2label[int(x)]
  128.                     strs_hyps.append(str_hyp)
  129.             except Exception as e:
  130.                 print(e)
  131.                 # logging.info("NaN predictions")
  132.                 continue

  133.             seq_length = pred.size(1)
  134.             sizes = Variable(src_percentages.mul_(int(seq_length)).int(), requires_grad=False)

  135.             loss, num_correct = calculate_metrics(
  136.                 pred, gold, input_lengths=sizes, target_lengths=tgt_lengths, smoothing=0, loss_type='ce')

  137.             if loss.item() == float('Inf'):
  138.                 # logging.info("Found infinity loss, masking")
  139.                 loss = torch.where(loss != loss, torch.zeros_like(loss), loss) # NaN masking
  140.                 continue

  141.             for j in range(len(strs_hyps)):
  142.                 strs_hyps[j] = strs_hyps[j].replace(SOS_CHAR, '').replace(EOS_CHAR, '')
  143.                 strs_gold[j] = strs_gold[j].replace(SOS_CHAR, '').replace(EOS_CHAR, '')
  144.                 cer = calculate_cer(strs_hyps[j].replace(' ', ''), strs_gold[j].replace(' ', ''))
  145.                 wer = calculate_wer(strs_hyps[j], strs_gold[j])
  146.                 total_cer += cer
  147.                 total_wer += wer
  148.                 total_char += len(strs_gold[j].replace(' ', ''))
  149.                 total_word += len(strs_gold[j].split(" "))

  150.             loss.backward()
  151.             optimizer.step()
  152.             # opt.step()
  153.             total_loss += loss.item()
  154.             print("(Epoch {}) TRAIN LOSS:{:.4f} CER:{:.2f}% ".format( (epoch+1), total_loss/(i+1), total_cer*100/total_char ) )

  155.         # evaluate
  156.         print("")
  157.         model.eval()
  158.         for ind in range(len(valid_loader_list)):
  159.             valid_loader = valid_loader_list[ind]

  160.             total_valid_loss, total_valid_cer, total_valid_wer, total_valid_char, total_valid_word = 0, 0, 0, 0, 0
  161.             # valid_pbar = tqdm(iter(valid_loader), leave=True, total=len(valid_loader))
  162.             # for i, (data) in enumerate(valid_pbar):
  163.             for i, (data) in enumerate(valid_loader):   
  164.                 src, tgt, src_percentages, src_lengths, tgt_lengths = data
  165.                 src = src.cuda()
  166.                 tgt = tgt.cuda()

  167.                 with torch.no_grad():
  168.                     pred, gold, hyp_seq, gold_seq = model(src, src_lengths, tgt, verbose=False)

  169.                 seq_length = pred.size(1)
  170.                 sizes = Variable(src_percentages.mul_(int(seq_length)).int(), requires_grad=False)

  171.                 loss, num_correct = calculate_metrics(
  172.                     pred, gold, input_lengths=sizes, target_lengths=tgt_lengths, smoothing=0, loss_type='ce')

  173.                 if loss.item() == float('Inf'):
  174.                     # logging.info("Found infinity loss, masking")
  175.                     loss = torch.where(loss != loss, torch.zeros_like(loss), loss) # NaN masking
  176.                     continue

  177.                 try: # handle case for CTC
  178.                     strs_gold, strs_hyps = [], []
  179.                     for ut_gold in gold_seq:
  180.                         str_gold = ""
  181.                         for x in ut_gold:
  182.                             if int(x) == PAD_TOKEN:
  183.                                 break
  184.                             str_gold = str_gold + id2label[int(x)]
  185.                         strs_gold.append(str_gold)
  186.                     for ut_hyp in hyp_seq:
  187.                         str_hyp = ""
  188.                         for x in ut_hyp:
  189.                             if int(x) == PAD_TOKEN:
  190.                                 break
  191.                             str_hyp = str_hyp + id2label[int(x)]
  192.                         strs_hyps.append(str_hyp)
  193.                 except Exception as e:
  194.                     print(e)
  195.                     # logging.info("NaN predictions")
  196.                     continue

  197.                 for j in range(len(strs_hyps)):
  198.                     strs_hyps[j] = strs_hyps[j].replace(SOS_CHAR, '').replace(EOS_CHAR, '')
  199.                     strs_gold[j] = strs_gold[j].replace(SOS_CHAR, '').replace(EOS_CHAR, '')
  200.                     cer = calculate_cer(strs_hyps[j].replace(' ', ''), strs_gold[j].replace(' ', ''))
  201.                     wer = calculate_wer(strs_hyps[j], strs_gold[j])
  202.                     total_valid_cer += cer
  203.                     total_valid_wer += wer
  204.                     total_valid_char += len(strs_gold[j].replace(' ', ''))
  205.                     total_valid_word += len(strs_gold[j].split(" "))

  206.                 total_valid_loss += loss.item()
  207.                 print("VALID SET {} LOSS:{:.4f} CER:{:.2f}%".format(ind, total_valid_loss/(i+1), total_valid_cer*100/total_valid_char))

  208.         if epoch % 5 == 0:
  209.             save_model(model, (epoch+1), metrics, label2id, id2label, best_model=False)  # 'metrics' is the module-level variable set in __main__ (None here)

  210.         # save the best model
  211.         if best_valid_loss > total_valid_loss/len(valid_loader):
  212.             best_valid_loss = total_valid_loss/len(valid_loader)
  213.             save_model(model, (epoch+1), metrics, label2id, id2label, best_model=True)

  214. class Transformer(nn.Module):
  215.     """
  216.     Transformer class
  217.     args:
  218.         encoder: Encoder object
  219.         decoder: Decoder object
  220.     """

  221.     def __init__(self, encoder, decoder, feat_extractor='vgg_cnn'):
  222.         super(Transformer, self).__init__()
  223.         self.encoder = encoder
  224.         self.decoder = decoder
  225.         self.id2label = decoder.id2label
  226.         self.feat_extractor = feat_extractor
  227.         
  228.         for p in self.parameters():
  229.             if p.dim() > 1:
  230.                 nn.init.xavier_uniform_(p)

  231.     def forward(self, padded_input, input_lengths, padded_target, verbose=False):
  232.         """
  233.         args:
  234.             padded_input: B x 1 (channel for spectrogram=1) x (freq) x T
  235.             padded_input: B x T x D
  236.             input_lengths: B
  237.             padded_target: B x T
  238.         output:
  239.             pred: B x T x vocab
  240.             gold: B x T
  241.         """
  242.         if self.feat_extractor == 'emb_cnn' or self.feat_extractor == 'vgg_cnn':
  243.             padded_input = self.conv(padded_input)  # NOTE: self.conv is not defined in this trimmed script; keep feat_extractor='' (as in __main__) so this branch is skipped

  244.         # Reshaping features
  245.         sizes = padded_input.size() # B x H_1 (channel?) x H_2 x T
  246.         padded_input = padded_input.view(sizes[0], sizes[1] * sizes[2], sizes[3])
  247.         padded_input = padded_input.transpose(1, 2).contiguous()  # BxTxH

  248.         encoder_padded_outputs, _ = self.encoder(padded_input, input_lengths)
  249.         pred, gold, *_ = self.decoder(padded_target, encoder_padded_outputs, input_lengths)
  250.         hyp_best_scores, hyp_best_ids = torch.topk(pred, 1, dim=2)

  251.         hyp_seq = hyp_best_ids.squeeze(2)
  252.         gold_seq = gold

  253.         return pred, gold, hyp_seq, gold_seq

  254.     def evaluate(self, padded_input, input_lengths, padded_target):
  255.         """
  256.         args:
  257.             padded_input: B x T x D
  258.             input_lengths: B
  259.             padded_target: B x T
  260.         output:
  261.             batch_ids_nbest_hyps: list of nbest id
  262.             batch_strs_nbest_hyps: list of nbest str
  263.             batch_strs_gold: list of gold str
  264.         """
  265.         if self.feat_extractor == 'emb_cnn' or self.feat_extractor == 'vgg_cnn':
  266.             padded_input = self.conv(padded_input)  # see note in forward(): self.conv is not defined in this trimmed script

  267.         # Reshaping features
  268.         sizes = padded_input.size() # B x H_1 (channel?) x H_2 x T
  269.         padded_input = padded_input.view(sizes[0], sizes[1] * sizes[2], sizes[3])
  270.         padded_input = padded_input.transpose(1, 2).contiguous()  # BxTxH

  271.         encoder_padded_outputs, _ = self.encoder(padded_input, input_lengths)
  272.         hyp, gold, *_ = self.decoder(padded_target, encoder_padded_outputs, input_lengths)
  273.         hyp_best_scores, hyp_best_ids = torch.topk(hyp, 1, dim=2)
  274.         
  275.         strs_gold = ["".join([self.id2label[int(x)] for x in gold_seq]) for gold_seq in gold]

  276.         strs_hyps = self.decoder.greedy_search(encoder_padded_outputs)
  277.         
  278.         return _, strs_hyps, strs_gold
  279.    
  280. class Encoder(nn.Module):
  281.     """
  282.     Encoder Transformer class
  283.     """
  284.     def __init__(self, num_layers, num_heads, dim_model, dim_key, dim_value, dim_input, dim_inner, dropout=0.1, src_max_length=2500):
  285.         super(Encoder, self).__init__()

  286.         self.dim_input = dim_input
  287.         self.num_layers = num_layers
  288.         self.num_heads = num_heads

  289.         self.dim_model = dim_model
  290.         self.dim_key = dim_key
  291.         self.dim_value = dim_value
  292.         self.dim_inner = dim_inner

  293.         self.src_max_length = src_max_length

  294.         self.dropout = nn.Dropout(dropout)
  295.         self.dropout_rate = dropout

  296.         self.input_linear = nn.Linear(dim_input, dim_model)
  297.         self.layer_norm_input = nn.LayerNorm(dim_model)
  298.         self.positional_encoding = PositionalEncoding(dim_model, src_max_length)

  299.         self.layers = nn.ModuleList([EncoderLayer(num_heads, dim_model, dim_inner, dim_key, dim_value, dropout=dropout) for _ in range(num_layers)])

  300.     def forward(self, padded_input, input_lengths):
  301.         """
  302.         args:
  303.             padded_input: B x T x D
  304.             input_lengths: B
  305.         return:
  306.             output: B x T x H
  307.         """
  308.         encoder_self_attn_list = []

  309.         # Prepare masks
  310.         non_pad_mask = get_non_pad_mask(padded_input, input_lengths=input_lengths)  # B x T x 1
  311.         seq_len = padded_input.size(1)
  312.         self_attn_mask = get_attn_pad_mask(padded_input, input_lengths, seq_len)  # B x T x T

  313.         encoder_output = self.layer_norm_input(self.input_linear(
  314.             padded_input)) + self.positional_encoding(padded_input)

  315.         for layer in self.layers:
  316.             encoder_output, self_attn = layer(
  317.                 encoder_output, non_pad_mask=non_pad_mask, self_attn_mask=self_attn_mask)
  318.             encoder_self_attn_list += [self_attn]

  319.         return encoder_output, encoder_self_attn_list


  320. class EncoderLayer(nn.Module):
  321.     """
  322.     Encoder Layer Transformer class
  323.     """

  324.     def __init__(self, num_heads, dim_model, dim_inner, dim_key, dim_value, dropout=0.1):
  325.         super(EncoderLayer, self).__init__()
  326.         self.self_attn = MultiHeadAttention(
  327.             num_heads, dim_model, dim_key, dim_value, dropout=dropout)
  328.         self.pos_ffn = PositionwiseFeedForwardWithConv(
  329.             dim_model, dim_inner, dropout=dropout)

  330.     def forward(self, enc_input, non_pad_mask=None, self_attn_mask=None):
  331.         enc_output, self_attn = self.self_attn(
  332.             enc_input, enc_input, enc_input, mask=self_attn_mask)
  333.         enc_output *= non_pad_mask

  334.         enc_output = self.pos_ffn(enc_output)
  335.         enc_output *= non_pad_mask

  336.         return enc_output, self_attn


  337. class Decoder(nn.Module):
  338.     """
  339.     Decoder Layer Transformer class
  340.     """

  341.     def __init__(self, id2label, num_src_vocab, num_trg_vocab, num_layers, num_heads, dim_emb, dim_model, dim_inner, dim_key, dim_value, dropout=0.1, trg_max_length=1000, emb_trg_sharing=False):
  342.         super(Decoder, self).__init__()
  343.         self.sos_id = SOS_TOKEN
  344.         self.eos_id = EOS_TOKEN

  345.         self.id2label = id2label

  346.         self.num_src_vocab = num_src_vocab
  347.         self.num_trg_vocab = num_trg_vocab
  348.         self.num_layers = num_layers
  349.         self.num_heads = num_heads

  350.         self.dim_emb = dim_emb
  351.         self.dim_model = dim_model
  352.         self.dim_inner = dim_inner
  353.         self.dim_key = dim_key
  354.         self.dim_value = dim_value

  355.         self.dropout_rate = dropout
  356.         self.emb_trg_sharing = emb_trg_sharing

  357.         self.trg_max_length = trg_max_length

  358.         self.trg_embedding = nn.Embedding(num_trg_vocab, dim_emb, padding_idx = PAD_TOKEN)
  359.         self.positional_encoding = PositionalEncoding(
  360.             dim_model, max_length=trg_max_length)
  361.         self.dropout = nn.Dropout(dropout)

  362.         self.layers = nn.ModuleList([
  363.             DecoderLayer(dim_model, dim_inner, num_heads,
  364.                          dim_key, dim_value, dropout=dropout)
  365.             for _ in range(num_layers)
  366.         ])

  367.         self.output_linear = nn.Linear(dim_model, num_trg_vocab, bias=False)
  368.         nn.init.xavier_normal_(self.output_linear.weight)

  369.         if emb_trg_sharing:
  370.             self.output_linear.weight = self.trg_embedding.weight
  371.             self.x_logit_scale = (dim_model ** -0.5)
  372.         else:
  373.             self.x_logit_scale = 1.0

  374.     def preprocess(self, padded_input):
  375.         """
  376.         Add SOS TOKEN and EOS TOKEN into padded_input
  377.         """
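        # Teacher forcing: the decoder input is [SOS, y1 ... yn] and the gold target is
        # [y1 ... yn, EOS] (the same sequence shifted by one step); inputs are padded with
        # the EOS id, targets with PAD_TOKEN.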
  378.         seq = [y[y != PAD_TOKEN] for y in padded_input]
  379.         eos = seq[0].new([self.eos_id])
  380.         sos = seq[0].new([self.sos_id])
  381.         seq_in = [torch.cat([sos, y], dim=0) for y in seq]
  382.         seq_out = [torch.cat([y, eos], dim=0) for y in seq]
  383.         seq_in_pad = pad_list(seq_in, self.eos_id)
  384.         seq_out_pad = pad_list(seq_out, PAD_TOKEN)
  385.         assert seq_in_pad.size() == seq_out_pad.size()
  386.         return seq_in_pad, seq_out_pad

  387.     def forward(self, padded_input, encoder_padded_outputs, encoder_input_lengths):
  388.         """
  389.         args:
  390.             padded_input: B x T
  391.             encoder_padded_outputs: B x T x H
  392.             encoder_input_lengths: B
  393.         returns:
  394.             pred: B x T x vocab
  395.             gold: B x T
  396.         """
  397.         decoder_self_attn_list, decoder_encoder_attn_list = [], []
  398.         seq_in_pad, seq_out_pad = self.preprocess(padded_input)

  399.         # Prepare masks
  400.         non_pad_mask = get_non_pad_mask(seq_in_pad, pad_idx = EOS_TOKEN)
  401.         self_attn_mask_subseq = get_subsequent_mask(seq_in_pad)
  402.         self_attn_mask_keypad = get_attn_key_pad_mask(seq_k=seq_in_pad, seq_q=seq_in_pad, pad_idx = EOS_TOKEN)
  403.         self_attn_mask = (self_attn_mask_keypad + self_attn_mask_subseq).gt(0)
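        # self_attn_mask combines the causal (no-future) mask with the key-padding mask;
        # padding is detected via EOS_TOKEN because preprocess() pads seq_in_pad with the EOS id.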

  404.         output_length = seq_in_pad.size(1)
  405.         dec_enc_attn_mask = get_attn_pad_mask(
  406.             encoder_padded_outputs, encoder_input_lengths, output_length)

  407.         decoder_output = self.dropout(self.trg_embedding(
  408.             seq_in_pad) * self.x_logit_scale + self.positional_encoding(seq_in_pad))

  409.         for layer in self.layers:
  410.             decoder_output, decoder_self_attn, decoder_enc_attn = layer(
  411.                 decoder_output, encoder_padded_outputs, non_pad_mask=non_pad_mask, self_attn_mask=self_attn_mask, dec_enc_attn_mask=dec_enc_attn_mask)

  412.             decoder_self_attn_list += [decoder_self_attn]
  413.             decoder_encoder_attn_list += [decoder_enc_attn]

  414.         seq_logit = self.output_linear(decoder_output)
  415.         pred, gold = seq_logit, seq_out_pad

  416.         return pred, gold, decoder_self_attn_list, decoder_encoder_attn_list

  417.     def post_process_hyp(self, hyp):
  418.         """
  419.         args:
  420.             hyp: list of hypothesis
  421.         output:
  422.             list of hypothesis (string)>
  423.         """
  424.         return "".join([self.id2label[int(x)] for x in hyp['yseq'][1:]])

  425.     def greedy_search(self, encoder_padded_outputs, beam_width=2, lm_rescoring=False, lm=None, lm_weight=0.1, c_weight=1):
  426.         """
  427.         Greedy search, decode 1-best utterance
  428.         args:
  429.             encoder_padded_outputs: B x T x H
  430.         output:
  431.             batch_ids_nbest_hyps: list of nbest in ids (size B)
  432.             batch_strs_nbest_hyps: list of nbest in strings (size B)
  433.         """
  434.         # max_seq_len = self.trg_max_length
  435.         ys = torch.ones(encoder_padded_outputs.size(0),1).fill_(SOS_TOKEN).long() # batch_size x 1
  436.         if True:
  437.             ys = ys.cuda()

  438.         decoded_words = []
  439.         for t in range(300):
  440.         # for t in range(max_seq_len):
  441.             # print(t)
  442.             # Prepare masks
  443.             non_pad_mask = torch.ones_like(ys).float().unsqueeze(-1) # batch_size x t x 1
  444.             self_attn_mask = get_subsequent_mask(ys) # batch_size x t x t

  445.             decoder_output = self.dropout(self.trg_embedding(ys) * self.x_logit_scale
  446.                                         + self.positional_encoding(ys))

  447.             for layer in self.layers:
  448.                 decoder_output, _, _ = layer(
  449.                     decoder_output, encoder_padded_outputs,
  450.                     non_pad_mask=non_pad_mask,
  451.                     self_attn_mask=self_attn_mask,
  452.                     dec_enc_attn_mask=None
  453.                 )

  454.             prob = self.output_linear(decoder_output) # batch_size x t x label_size

  455.             _, next_word = torch.max(prob[:, -1], dim=1)
  456.             decoded_words.append([EOS_CHAR if ni.item() == EOS_TOKEN else self.id2label[ni.item()] for ni in next_word.view(-1)])
  457.             next_word = next_word.unsqueeze(-1)

  458.             if True:
  459.                 ys = torch.cat([ys, next_word.cuda()], dim=1)
  460.                 ys = ys.cuda()
  461.             else:
  462.                 ys = torch.cat([ys, next_word], dim=1)

  463.         sent = []
  464.         for _, row in enumerate(np.transpose(decoded_words)):
  465.             st = ''
  466.             for e in row:
  467.                 if e == EOS_CHAR:
  468.                     break
  469.                 else:
  470.                     st += e
  471.             sent.append(st)
  472.         return sent

  473. class DecoderLayer(nn.Module):
  474.     """
  475.     Decoder Transformer class
  476.     """

  477.     def __init__(self, dim_model, dim_inner, num_heads, dim_key, dim_value, dropout=0.1):
  478.         super(DecoderLayer, self).__init__()
  479.         self.self_attn = MultiHeadAttention(
  480.             num_heads, dim_model, dim_key, dim_value, dropout=dropout)
  481.         self.encoder_attn = MultiHeadAttention(
  482.             num_heads, dim_model, dim_key, dim_value, dropout=dropout)
  483.         self.pos_ffn = PositionwiseFeedForwardWithConv(
  484.             dim_model, dim_inner, dropout=dropout)

  485.     def forward(self, decoder_input, encoder_output, non_pad_mask=None, self_attn_mask=None, dec_enc_attn_mask=None):
  486.         decoder_output, decoder_self_attn = self.self_attn(
  487.             decoder_input, decoder_input, decoder_input, mask=self_attn_mask)
  488.         decoder_output *= non_pad_mask

  489.         decoder_output, decoder_encoder_attn = self.encoder_attn(
  490.             decoder_output, encoder_output, encoder_output, mask=dec_enc_attn_mask)
  491.         decoder_output *= non_pad_mask

  492.         decoder_output = self.pos_ffn(decoder_output)
  493.         decoder_output *= non_pad_mask

  494.         return decoder_output, decoder_self_attn, decoder_encoder_attn        

  495. """
  496. General purpose functions
  497. """
  498. def pad_list(xs, pad_value):
  499.     # From: espnet/src/nets/e2e_asr_th.py: pad_list()
  500.     n_batch = len(xs)
  501.     # max_len = max(x.size(0) for x in xs)
  502.     # tgt_max_len=1000
  503.     max_len = 1000
  504.     pad = xs[0].new(n_batch, max_len, * xs[0].size()[1:]).fill_(pad_value)
  505.     for i in range(n_batch):
  506.         pad[i, :xs[i].size(0)] = xs[i]
  507.     return pad

  508. """
  509. Transformer common layers
  510. """
  511. def get_non_pad_mask(padded_input, input_lengths=None, pad_idx=None):
  512.     """
  513.     padding position is set to 0, either use input_lengths or pad_idx
  514.     """
  515.     assert input_lengths is not None or pad_idx is not None
  516.     if input_lengths is not None:
  517.         # padded_input: N x T x ..
  518.         N = padded_input.size(0)
  519.         non_pad_mask = padded_input.new_ones(padded_input.size()[:-1])  # B x T
  520.         for i in range(N):
  521.             non_pad_mask[i, input_lengths[i]:] = 0
  522.     if pad_idx is not None:
  523.         # padded_input: N x T
  524.         assert padded_input.dim() == 2
  525.         non_pad_mask = padded_input.ne(pad_idx).float()
  526.     # unsqueeze(-1) for broadcast
  527.     return non_pad_mask.unsqueeze(-1)

  528. def get_attn_key_pad_mask(seq_k, seq_q, pad_idx):
  529.     """
  530.     For masking out the padding part of key sequence.
  531.     """
  532.     # Expand to fit the shape of key query attention matrix.
  533.     len_q = seq_q.size(1)
  534.     padding_mask = seq_k.eq(pad_idx)
  535.     padding_mask = padding_mask.unsqueeze(1).expand(-1, len_q, -1)  # B x T_Q x T_K

  536.     return padding_mask

  537. def get_attn_pad_mask(padded_input, input_lengths, expand_length):
  538.     """mask position is set to 1"""
  539.     # N x Ti x 1
  540.     non_pad_mask = get_non_pad_mask(padded_input, input_lengths=input_lengths)
  541.     # N x Ti, lt(1) like not operation
  542.     pad_mask = non_pad_mask.squeeze(-1).lt(1)
  543.     attn_mask = pad_mask.unsqueeze(1).expand(-1, expand_length, -1)
  544.     return attn_mask

  545. def get_subsequent_mask(seq):
  546.     ''' For masking out the subsequent info. '''

  547.     sz_b, len_s = seq.size()
  548.     subsequent_mask = torch.triu(
  549.         torch.ones((len_s, len_s), device=seq.device, dtype=torch.uint8), diagonal=1)
  550.     subsequent_mask = subsequent_mask.unsqueeze(0).expand(sz_b, -1, -1)  # b x ls x ls

  551.     return subsequent_mask

  552. class PositionalEncoding(nn.Module):
  553.     """
  554.     Positional Encoding class
  555.     """
  556.     def __init__(self, dim_model, max_length=2000):
  557.         super(PositionalEncoding, self).__init__()

  558.         pe = torch.zeros(max_length, dim_model, requires_grad=False)
  559.         position = torch.arange(0, max_length).unsqueeze(1).float()
  560.         exp_term = torch.exp(torch.arange(0, dim_model, 2).float() * -(math.log(10000.0) / dim_model))
  561.         pe[:, 0::2] = torch.sin(position * exp_term) # even indices get sin
  562.         pe[:, 1::2] = torch.cos(position * exp_term) # odd indices get cos
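        # i.e. PE(pos, 2i) = sin(pos / 10000^(2i/dim_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/dim_model))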
  563.         pe = pe.unsqueeze(0)
  564.         self.register_buffer('pe', pe)

  565.     def forward(self, input):
  566.         """
  567.         args:
  568.             input: B x T x D
  569.         output:
  570.             tensor: 1 x T x dim_model (positional encodings, broadcast over the batch)
  571.         """
  572.         return self.pe[:, :input.size(1)]

  573. class PositionwiseFeedForward(nn.Module):
  574.     """
  575.     Position-wise Feedforward Layer class
  576.     FFN(x) = max(0, xW1 + b1) W2+ b2
  577.     """
  578.     def __init__(self, dim_model, dim_ff, dropout=0.1):
  579.         super(PositionwiseFeedForward, self).__init__()
  580.         self.linear_1 = nn.Linear(dim_model, dim_ff)
  581.         self.linear_2 = nn.Linear(dim_ff, dim_model)
  582.         self.dropout = nn.Dropout(dropout)
  583.         self.layer_norm = nn.LayerNorm(dim_model)

  584.     def forward(self, x):
  585.         """
  586.         args:
  587.             x: tensor
  588.         output:
  589.             y: tensor
  590.         """
  591.         residual = x
  592.         output = self.dropout(self.linear_2(F.relu(self.linear_1(x))))
  593.         output = self.layer_norm(output + residual)
  594.         return output

  595. class PositionwiseFeedForwardWithConv(nn.Module):
  596.     """
  597.     Position-wise Feedforward Layer Implementation with Convolution class
  598.     """
  599.     def __init__(self, dim_model, dim_hidden, dropout=0.1):
  600.         super(PositionwiseFeedForwardWithConv, self).__init__()
  601.         self.conv_1 = nn.Conv1d(dim_model, dim_hidden, 1)
  602.         self.conv_2 = nn.Conv1d(dim_hidden, dim_model, 1)
  603.         self.dropout = nn.Dropout(dropout)
  604.         self.layer_norm = nn.LayerNorm(dim_model)

  605.     def forward(self, x):
  606.         residual = x
  607.         output = x.transpose(1, 2)
  608.         output = self.conv_2(F.relu(self.conv_1(output)))
  609.         output = output.transpose(1, 2)
  610.         output = self.dropout(output)
  611.         output = self.layer_norm(output + residual)
  612.         return output

  613. class MultiHeadAttention(nn.Module):
  614.     def __init__(self, num_heads, dim_model, dim_key, dim_value, dropout=0.1):
  615.         super(MultiHeadAttention, self).__init__()

  616.         self.num_heads = num_heads

  617.         self.dim_model = dim_model
  618.         self.dim_key = dim_key
  619.         self.dim_value = dim_value

  620.         self.query_linear = nn.Linear(dim_model, num_heads * dim_key)
  621.         self.key_linear = nn.Linear(dim_model, num_heads * dim_key)
  622.         self.value_linear = nn.Linear(dim_model, num_heads * dim_value)

  623.         nn.init.normal_(self.query_linear.weight, mean=0, std=np.sqrt(2.0 / (self.dim_model + self.dim_key)))
  624.         nn.init.normal_(self.key_linear.weight, mean=0, std=np.sqrt(2.0 / (self.dim_model + self.dim_key)))
  625.         nn.init.normal_(self.value_linear.weight, mean=0, std=np.sqrt(2.0 / (self.dim_model + self.dim_value)))

  626.         self.attention = ScaledDotProductAttention(temperature=np.power(dim_key, 0.5), attn_dropout=dropout)
  627.         self.layer_norm = nn.LayerNorm(dim_model)

  628.         self.output_linear = nn.Linear(num_heads * dim_value, dim_model)
  629.         nn.init.xavier_normal_(self.output_linear.weight)
  630.         
  631.         self.dropout = nn.Dropout(dropout)

  632.     def forward(self, query, key, value, mask=None):
  633.         """
  634.         query: B x T_Q x H, key: B x T_K x H, value: B x T_V x H
  635.         mask: B x T x T (attention mask)
  636.         """
  637.         batch_size, len_query, _ = query.size()
  638.         batch_size, len_key, _ = key.size()
  639.         batch_size, len_value, _ = value.size()

  640.         residual = query

  641.         query = self.query_linear(query).view(batch_size, len_query, self.num_heads, self.dim_key) # B x T_Q x num_heads x H_K
  642.         key = self.key_linear(key).view(batch_size, len_key, self.num_heads, self.dim_key) # B x T_K x num_heads x H_K
  643.         value = self.value_linear(value).view(batch_size, len_value, self.num_heads, self.dim_value) # B x T_V x num_heads x H_V

  644.         query = query.permute(2, 0, 1, 3).contiguous().view(-1, len_query, self.dim_key) # (num_heads * B) x T_Q x H_K
  645.         key = key.permute(2, 0, 1, 3).contiguous().view(-1, len_key, self.dim_key) # (num_heads * B) x T_K x H_K
  646.         value = value.permute(2, 0, 1, 3).contiguous().view(-1, len_value, self.dim_value) # (num_heads * B) x T_V x H_V

  647.         if mask is not None:
  648.             mask = mask.repeat(self.num_heads, 1, 1) # (B * num_head) x T x T
  649.         
  650.         output, attn = self.attention(query, key, value, mask=mask)

  651.         output = output.view(self.num_heads, batch_size, len_query, self.dim_value) # num_heads x B x T_Q x H_V
  652.         output = output.permute(1, 2, 0, 3).contiguous().view(batch_size, len_query, -1) # B x T_Q x (num_heads * H_V)

  653.         output = self.dropout(self.output_linear(output)) # B x T_Q x H_O
  654.         output = self.layer_norm(output + residual)

  655.         return output, attn

  656. class ScaledDotProductAttention(nn.Module):
  657.     ''' Scaled Dot-Product Attention '''

  658.     def __init__(self, temperature, attn_dropout=0.1):
  659.         super().__init__()
  660.         self.temperature = temperature
  661.         self.dropout = nn.Dropout(attn_dropout)
  662.         self.softmax = nn.Softmax(dim=2)

  663.     def forward(self, q, k, v, mask=None):
  664.         """
        Attention(Q, K, V) = softmax(Q K^T / temperature) V, with temperature = sqrt(dim_key)
  665.         """
  666.         attn = torch.bmm(q, k.transpose(1, 2))
  667.         attn = attn / self.temperature

  668.         if mask is not None:
  669.             attn = attn.masked_fill(mask, -np.inf)

  670.         attn = self.softmax(attn)
  671.         attn = self.dropout(attn)
  672.         output = torch.bmm(attn, v)

  673.         return output, attn

  674. def save_model(model, epoch, metrics, label2id, id2label, best_model=False):
  675.     """
  676.     Saving model, TODO adding history
  677.     """
  678.     save_folder = 'models/'
  679.     name = 'model'
  680.     if best_model:
  681.         save_path = "{}/{}/best_model.th".format(save_folder, name)
  682.     else:
  683.         save_path = "{}/{}/epoch_{}.th".format(save_folder, name, epoch)

  684.     if not os.path.exists(save_folder + "/" + name):
  685.         os.makedirs(save_folder + "/" + name)

  686.     print("SAVE MODEL to", save_path)
  687.     args = {
  688.         'label2id': label2id,
  689.         'id2label': id2label,
  690.         'epoch': epoch,
  691.         'model_state_dict': model.state_dict(),
  692.         'metrics': metrics
  693.     }
  694.     torch.save(args, save_path)

  695. def init_transformer_model2(num_layers, num_heads, dim_model, dim_key, dim_value, dim_input, dim_inner, dim_emb, src_max_len, tgt_max_len, dropout, emb_trg_sharing, feat_extractor, label2id, id2label):
  696.     """
  697.     Initiate a new transformer object
  698.     """
  699.     encoder = Encoder(num_layers, num_heads=num_heads, dim_model=dim_model, dim_key=dim_key,
  700.                       dim_value=dim_value, dim_input=dim_input, dim_inner=dim_inner, src_max_length=src_max_len, dropout=dropout)
  701.     decoder = Decoder(id2label, num_src_vocab=len(label2id), num_trg_vocab=len(label2id), num_layers=num_layers, num_heads=num_heads,
  702.                       dim_emb=dim_emb, dim_model=dim_model, dim_inner=dim_inner, dim_key=dim_key, dim_value=dim_value, trg_max_length=tgt_max_len, dropout=dropout, emb_trg_sharing=emb_trg_sharing)
  703.     model = Transformer(encoder, decoder, feat_extractor=feat_extractor)

  704.     return model

  705. def load_audio(path):
  706.     # sound, _ = torchaudio.load(path, normalization=True)
  707.     sound, _ = torchaudio.load(path)
  708.     sound = sound.numpy().T
  709.     if len(sound.shape) > 1:
  710.         if sound.shape[1] == 1:
  711.             sound = sound.squeeze()
  712.         else:
  713.             sound = sound.mean(axis=1)  # multiple channels, average
  714.     return sound

  715. def get_audio_length(path):
  716.     output = subprocess.check_output(
  717.         ['soxi -D "%s"' % path.strip()], shell=True)
  718.     return float(output)

  719. def audio_with_sox(path, sample_rate, start_time, end_time):
  720.     """
  721.     crop and resample the recording with sox and loads it.
  722.     """
  723.     with NamedTemporaryFile(suffix=".wav") as tar_file:
  724.         tar_filename = tar_file.name
  725.         sox_params = 'sox "{}" -r {} -c 1 -b 16 -e si {} trim {} ={} >/dev/null 2>&1'.format(path, sample_rate, tar_filename, start_time, end_time)
  726.         os.system(sox_params)
  727.         y = load_audio(tar_filename)
  728.         return y

  729. def augment_audio_with_sox(path, sample_rate, tempo, gain):
  730.     """
  731.     Changes tempo and gain of the recording with sox and loads it.
  732.     """
  733.     with NamedTemporaryFile(suffix=".wav") as augmented_file:
  734.         augmented_filename = augmented_file.name
  735.         sox_augment_params = ["tempo", "{:.3f}".format(
  736.             tempo), "gain", "{:.3f}".format(gain)]
  737.         sox_params = 'sox "{}" -r {} -c 1 -b 16 -e si {} {} >/dev/null 2>&1'.format(
  738.             path, sample_rate, augmented_filename, " ".join(sox_augment_params))
  739.         os.system(sox_params)
  740.         y = load_audio(augmented_filename)
  741.         return y


  742. def load_randomly_augmented_audio(path, sample_rate=16000, tempo_range=(0.85, 1.15), gain_range=(-6, 8)):
  743.     """
  744.     Picks tempo and gain uniformly, applies it to the utterance by using sox utility.
  745.     Returns the augmented utterance.
  746.     """
  747.     low_tempo, high_tempo = tempo_range
  748.     tempo_value = np.random.uniform(low=low_tempo, high=high_tempo)
  749.     low_gain, high_gain = gain_range
  750.     gain_value = np.random.uniform(low=low_gain, high=high_gain)
  751.     audio = augment_audio_with_sox(path=path, sample_rate=sample_rate,
  752.                                    tempo=tempo_value, gain=gain_value)
  753.     return audio

  754. class AudioParser(object):
  755.     def parse_transcript(self, transcript_path):
  756.         """
  757.         :param transcript_path: Path where transcript is stored from the manifest file
  758.         :return: Transcript in training/testing format
  759.         """
  760.         raise NotImplementedError

  761.     def parse_audio(self, audio_path):
  762.         """
  763.         :param audio_path: Path where audio is stored from the manifest file
  764.         :return: Audio in training/testing format
  765.         """
  766.         raise NotImplementedError


  767. class SpectrogramParser(AudioParser):
  768.     def __init__(self, audio_conf, normalize=False, augment=False):
  769.         """
  770.         Parses audio file into spectrogram with optional normalization and various augmentations
  771.         :param audio_conf: Dictionary containing the sample rate, window and the window length/stride in seconds
  772.         :param normalize(default False):  Apply standard mean and deviation normalization to audio tensor
  773.         :param augment(default False):  Apply random tempo and gain perturbations
  774.         """
  775.         super(SpectrogramParser, self).__init__()
  776.         self.window_stride = audio_conf['window_stride']
  777.         self.window_size = audio_conf['window_size']
  778.         self.sample_rate = audio_conf['sample_rate']
  779.         self.window = windows.get(audio_conf['window'], windows['hamming'])
  780.         self.normalize = normalize
  781.         self.augment = augment
  782.         self.noiseInjector = NoiseInjection(audio_conf['noise_dir'], self.sample_rate,
  783.                                             audio_conf['noise_levels']) if audio_conf.get(
  784.             'noise_dir') is not None else None
  785.         self.noise_prob = audio_conf.get('noise_prob')

  786.     def parse_audio(self, audio_path):
  787.         if self.augment:
  788.             y = load_randomly_augmented_audio(audio_path, self.sample_rate)
  789.         else:
  790.             y = load_audio(audio_path)

  791.         if self.noiseInjector:
  792.             # logging.info("inject noise")
  793.             add_noise = np.random.binomial(1, self.noise_prob)
  794.             if add_noise:
  795.                 y = self.noiseInjector.inject_noise(y)

  796.         n_fft = int(self.sample_rate * self.window_size)
  797.         win_length = n_fft
  798.         hop_length = int(self.sample_rate * self.window_stride)

  799.         # Short-time Fourier transform (STFT)
  800.         D = librosa.stft(y, n_fft=n_fft, hop_length=hop_length,
  801.                          win_length=win_length, window=self.window)
  802.         spect, phase = librosa.magphase(D)

  803.         # S = log(S+1)
  804.         spect = np.log1p(spect)
  805.         spect = torch.FloatTensor(spect)

  806.         if self.normalize:
  807.             mean = spect.mean()
  808.             std = spect.std()
  809.             spect.add_(-mean)
  810.             spect.div_(std)

  811.         return spect

  812.     def parse_transcript(self, transcript_path):
  813.         raise NotImplementedError


  814. class SpectrogramDataset(Dataset, SpectrogramParser):
  815.     def __init__(self, audio_conf, manifest_filepath_list, label2id, normalize=False, augment=False):
  816.         """
  817.         Dataset that loads tensors via a csv containing file paths to audio files and transcripts separated by
  818.         a comma. Each new line is a different sample. Example below:
  819.         /path/to/audio.wav,/path/to/audio.txt
  820.         ...
  821.         :param audio_conf: Dictionary containing the sample rate, window and the window length/stride in seconds
  822.         :param manifest_filepath: Path to manifest csv as describe above
  823.         :param labels: String containing all the possible characters to map to
  824.         :param normalize: Apply standard mean and deviation normalization to audio tensor
  825.         :param augment(default False):  Apply random tempo and gain perturbations
  826.         """
  827.         self.max_size = 0
  828.         self.ids_list = []
  829.         for i in range(len(manifest_filepath_list)):
  830.             manifest_filepath = manifest_filepath_list[i]
  831.             with open(manifest_filepath) as f:
  832.                 ids = f.readlines()

  833.             ids = [x.strip().split(',') for x in ids]
  834.             self.ids_list.append(ids)
  835.             self.max_size = max(len(ids), self.max_size)

  836.         self.manifest_filepath_list = manifest_filepath_list
  837.         self.label2id = label2id
  838.         super(SpectrogramDataset, self).__init__(
  839.             audio_conf, normalize, augment)

  840.     def __getitem__(self, index):
  841.         random_id = random.randint(0, len(self.ids_list)-1)
  842.         ids = self.ids_list[random_id]
  843.         sample = ids[index % len(ids)]
  844.         audio_path, transcript_path = sample[0], sample[1]
  845.         src_max_len = 4000
  846.         spect = self.parse_audio(audio_path)[:,:src_max_len]
  847.         transcript = self.parse_transcript(transcript_path)
  848.         return spect, transcript

  849.     def parse_transcript(self, transcript_path):
  850.         with open(transcript_path, 'r', encoding='utf8') as transcript_file:
  851.             transcript = SOS_CHAR + transcript_file.read().replace('\n', '').lower() + EOS_CHAR

  852.         transcript = list(
  853.             filter(None, [self.label2id.get(x) for x in list(transcript)]))
  854.         return transcript

  855.     def __len__(self):
  856.         return self.max_size

  857. class NoiseInjection(object):
  858.     def __init__(self,
  859.                  path=None,
  860.                  sample_rate=16000,
  861.                  noise_levels=(0, 0.5)):
  862.         """
  863.         Adds noise to an input signal with specific SNR. Higher the noise level, the more noise added.
  864.         Modified code from https://github.com/willfrey/audio/blob/master/torchaudio/transforms.py
  865.         """
  866.         if not os.path.exists(path):
  867.             print("Directory doesn't exist: {}".format(path))
  868.             raise IOError
  869.         self.paths = path is not None and librosa.util.find_files(path)
  870.         self.sample_rate = sample_rate
  871.         self.noise_levels = noise_levels

  872.     def inject_noise(self, data):
  873.         noise_path = np.random.choice(self.paths)
  874.         noise_level = np.random.uniform(*self.noise_levels)
  875.         return self.inject_noise_sample(data, noise_path, noise_level)

  876.     def inject_noise_sample(self, data, noise_path, noise_level):
  877.         noise_len = get_audio_length(noise_path)
  878.         data_len = len(data) / self.sample_rate
  879.         noise_start = np.random.rand() * (noise_len - data_len)
  880.         noise_end = noise_start + data_len
  881.         noise_dst = audio_with_sox(
  882.             noise_path, self.sample_rate, noise_start, noise_end)
  883.         assert len(data) == len(noise_dst)
  884.         noise_energy = np.sqrt(noise_dst.dot(noise_dst) / noise_dst.size)
  885.         data_energy = np.sqrt(data.dot(data) / data.size)
  886.         data += noise_level * noise_dst * data_energy / noise_energy
  887.         return data

  888. def _collate_fn(batch):
  889.     def func(p):
  890.         return p[0].size(1)

  891.     def func_tgt(p):
  892.         return len(p[1])

  893.     # descending sorted
  894.     batch = sorted(batch, key=lambda sample: sample[0].size(1), reverse=True)

  895.     max_seq_len = max(batch, key=func)[0].size(1)
  896.     freq_size = max(batch, key=func)[0].size(0)
  897.     max_tgt_len = len(max(batch, key=func_tgt)[1])
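      # Pads every spectrogram to the longest utterance in the batch:
      #   inputs:  B x 1 x freq x max_seq_len,  targets: B x max_tgt_len,
      #   input_percentages[i] = true_length_i / max_seq_len (used later to recover frame counts).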

  898.     inputs = torch.zeros(len(batch), 1, freq_size, max_seq_len)
  899.     input_sizes = torch.IntTensor(len(batch))
  900.     input_percentages = torch.FloatTensor(len(batch))

  901.     targets = torch.zeros(len(batch), max_tgt_len).long()
  902.     target_sizes = torch.IntTensor(len(batch))

  903.     for x in range(len(batch)):
  904.         sample = batch[x]
  905.         input_data = sample[0]
  906.         target = sample[1]
  907.         seq_length = input_data.size(1)
  908.         input_sizes[x] = seq_length
  909.         inputs[x][0].narrow(1, 0, seq_length).copy_(input_data)
  910.         input_percentages[x] = seq_length / float(max_seq_len)
  911.         target_sizes[x] = len(target)
  912.         targets[x][:len(target)] = torch.IntTensor(target)

  913.     return inputs, targets, input_percentages, input_sizes, target_sizes

  914. class AudioDataLoader(DataLoader):
  915.     def __init__(self, *args, **kwargs):
  916.         super(AudioDataLoader, self).__init__(*args, **kwargs)
  917.         self.collate_fn = _collate_fn

  918. class BucketingSampler(Sampler):
  919.     def __init__(self, data_source, batch_size=1):
  920.         """
  921.         Samples batches assuming they are in order of size to batch similarly sized samples together.
  922.         """
  923.         super(BucketingSampler, self).__init__(data_source)
  924.         self.data_source = data_source
  925.         ids = list(range(0, len(data_source)))
  926.         self.bins = [ids[i:i + batch_size]
  927.                      for i in range(0, len(ids), batch_size)]
  928.     def __iter__(self):
  929.         for ids in self.bins:
  930.             np.random.shuffle(ids)
  931.             yield ids
  932.     def __len__(self):
  933.         return len(self.bins)
  934.     def shuffle(self, epoch):
  935.         np.random.shuffle(self.bins)

  936. if __name__ == '__main__':
  937.     batch_size=8
  938.     dim_emb=512
  939.     dim_inner=1024
  940.     dim_input=161
  941.     dim_key=64
  942.     dim_model=512
  943.     dim_value=64
  944.     dropout=0.1
  945.     emb_trg_sharing=False
  946.     epochs=300
  947.     # feat_extractor='vgg_cnn'
  948.     feat_extractor=''
  949.     labels_path=r'D:\2-LearningCode\902-ASR\AISHELL-1\data_aishell\LabelWav.json'
  950.     noise_dir=None
  951.     noise_max=0.5
  952.     noise_min=0.0
  953.     noise_prob=0.4
  954.     num_heads=5
  955.     num_layers=3
  956.     num_workers=0
  957.     prob_weight=1.0
  958.     sample_rate=16000
  959.     shuffle=False
  960.     src_max_len=4000
  961.     test_manifest_list=[r'D:\2-LearningCode\902-ASR\AISHELL-1\data_aishell\ASRInfo_test.txt']
  962.     tgt_max_len=1000
  963.     train_manifest_list=[r'D:\2-LearningCode\902-ASR\AISHELL-1\data_aishell\ASRInfo_train.txt']
  964.     valid_manifest_list=[r'D:\2-LearningCode\902-ASR\AISHELL-1\data_aishell\ASRInfo_val.txt']
  965.     window='hamming'
  966.     window_size=0.02
  967.     window_stride=0.01
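      # with sample_rate=16000 and window_size=0.02 the STFT uses n_fft = 320,
      # so each frame has 1 + 320/2 = 161 frequency bins, which is why dim_input = 161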
  968.    
  969.     print("="*50)
  970.     if not os.path.exists("./log"):
  971.         os.mkdir("./log")
  972.     audio_conf = dict(sample_rate = sample_rate,
  973.                       window_size = window_size,
  974.                       window_stride = window_stride,
  975.                       window = window,
  976.                       noise_dir = noise_dir,
  977.                       noise_prob = noise_prob,
  978.                       noise_levels = (noise_min, noise_max))

  979.     with open(labels_path, 'r', encoding="utf-8") as label_file:
  980.         labels = str(''.join(json.load(label_file)))
  981.     # add PAD_CHAR, SOS_CHAR, EOS_CHAR
  982.     labels = PAD_CHAR + SOS_CHAR + EOS_CHAR + labels
  983.     label2id, id2label = {}, {}
  984.     count = 0
  985.     for i in range(len(labels)):
  986.         if labels[i] not in label2id:
  987.             label2id[labels[i]] = count
  988.             id2label[count] = labels[i]
  989.             count += 1
  990.         else:
  991.             print("multiple label: ", labels[i])

  992.     train_data = SpectrogramDataset(audio_conf, manifest_filepath_list= train_manifest_list, label2id=label2id, normalize=True, augment = False)
  993.     train_sampler = BucketingSampler(train_data, batch_size = batch_size)
  994.     train_loader = AudioDataLoader(train_data, num_workers = num_workers, batch_sampler = train_sampler)

  995.     valid_loader_list, test_loader_list = [], []
  996.     for i in range(len(valid_manifest_list)):
  997.         valid_data = SpectrogramDataset(audio_conf, manifest_filepath_list=[ valid_manifest_list[i]], label2id=label2id, normalize=True, augment=False)
  998.         valid_loader = AudioDataLoader(valid_data, num_workers = num_workers, batch_size= batch_size)
  999.         valid_loader_list.append(valid_loader)

  1000.     start_epoch = 0
  1001.     metrics = None
  1002.     model = init_transformer_model2(num_layers, num_heads, dim_model, dim_key, dim_value, dim_input, dim_inner, dim_emb,
  1003.                                     src_max_len, tgt_max_len, dropout, emb_trg_sharing, feat_extractor, label2id, id2label)
  1004.     model = model.cuda(0)
  1005.     train(model, train_loader, train_sampler, valid_loader_list, start_epoch, epochs, label2id, id2label, metrics)
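Not part of the script above — a minimal inference sketch. It assumes you run it at the end of the script (so init_transformer_model2, valid_loader, label2id and id2label are in scope), reuses the hyper-parameters from __main__, and loads a checkpoint written by save_model(); the checkpoint path is an assumption.

import torch

# rebuild the model with the same hyper-parameters used for training (assumed)
checkpoint = torch.load('models/model/best_model.th')
model = init_transformer_model2(num_layers=3, num_heads=5, dim_model=512, dim_key=64,
                                dim_value=64, dim_input=161, dim_inner=1024, dim_emb=512,
                                src_max_len=4000, tgt_max_len=1000, dropout=0.1,
                                emb_trg_sharing=False, feat_extractor='',
                                label2id=checkpoint['label2id'], id2label=checkpoint['id2label'])
model.load_state_dict(checkpoint['model_state_dict'])
model = model.cuda(0).eval()

# greedy-decode one validation batch and compare hypotheses with references
src, tgt, src_percentages, src_lengths, tgt_lengths = next(iter(valid_loader))
with torch.no_grad():
    _, strs_hyps, strs_gold = model.evaluate(src.cuda(), src_lengths, tgt.cuda())
print(strs_hyps[0])
print(strs_gold[0])  # note: gold strings still contain the padding character '¶'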







Posted on 2023-7-30 20:58:10
Algorithm QQ: 3283892722
Swarm intelligence algorithms: http://halcom.cn/forum.php?mod=forumdisplay&fid=73

Posted on 2023-8-12 10:35:05
Pitfalls:
[1] Installing librosa in a libtorch 1.4 + torchvision 0.5.0 environment:
https://blog.csdn.net/ysw123123/ ... A%22ysw123123%22%7D
[2] Update your packages: conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=10.2 -c pytorch
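
A quick sanity check after the install (a minimal sketch; the exact versions printed depend on your environment):

import torch, torchvision, torchaudio, librosa
print(torch.__version__, torchvision.__version__, torchaudio.__version__, librosa.__version__)
print(torch.cuda.is_available())  # should print True if cudatoolkit 10.2 matches your driver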




