End2End-ASR-Pytorch
Bilibili videos:
[1] ASR debugging video: https://www.bilibili.com/video/B ... 24aa2e8ea036f9d24f4
[2] ASR: installing torchaudio
[3] ASR end-to-end code walkthrough (trimmed): https://www.bilibili.com/video/B ... 3af3339cdeadc088281
[4] ASR model training video
[5] Step5_ASR simplified code walkthrough
- import sys
- import os
- import json
- import math
- import numpy as np
- import random
- import unicodedata
- import scipy.signal
- import subprocess
- from tempfile import NamedTemporaryFile
- import Levenshtein as Lev
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from torch.autograd import Variable
- from torch.utils.data import DataLoader
- from torch.utils.data import Dataset
- from torch.utils.data.sampler import Sampler
- import librosa
- import torchaudio
- # newer SciPy versions expose these window functions under scipy.signal.windows (the old top-level aliases are deprecated/removed)
- windows = {'hamming': scipy.signal.windows.hamming, 'hann': scipy.signal.windows.hann, 'blackman': scipy.signal.windows.blackman, 'bartlett': scipy.signal.windows.bartlett}
- PAD_TOKEN = 0
- SOS_TOKEN = 1
- EOS_TOKEN = 2
- PAD_CHAR = "¶"
- SOS_CHAR = "§"
- EOS_CHAR = "¤"
- def calculate_cer(s1, s2):
- """
- Computes the character-level edit distance between the two strings; the caller divides by the number of reference characters to obtain CER.
- Arguments:
- s1 (string): space-separated sentence (hyp)
- s2 (string): space-separated sentence (gold)
- """
- return Lev.distance(s1, s2)
- def calculate_wer(s1, s2):
- """
- Computes the Word Error Rate, defined as the edit distance between the
- two provided sentences after tokenizing to words.
- Arguments:
- s1 (string): space-separated sentence
- s2 (string): space-separated sentence
- """
- # build mapping of words to integers
- b = set(s1.split() + s2.split())
- word2char = dict(zip(b, range(len(b))))
- # map the words to a char array (Levenshtein packages only accepts
- # strings)
- w1 = [chr(word2char[w]) for w in s1.split()]
- w2 = [chr(word2char[w]) for w in s2.split()]
- return Lev.distance(''.join(w1), ''.join(w2))
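- # Illustrative sketch (not part of the original code): how the two metrics behave on a toy pair,
- # assuming the hypothesis "ni hao" against the gold "ni hao ma":
- #   calculate_cer("nihao", "nihaoma")     -> 2  (two character edits)
- #   calculate_wer("ni hao", "ni hao ma")  -> 1  (one word insertion)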
- def calculate_metrics(pred, gold, input_lengths=None, target_lengths=None, smoothing=0.0, loss_type="ce"):
- """
- Calculate metrics
- args:
- pred: B x T x C
- gold: B x T
- input_lengths: B (for CTC)
- target_lengths: B (for CTC)
- """
- loss = calculate_loss(pred, gold, input_lengths, target_lengths, smoothing, loss_type)
- # if loss_type == "ce":
- pred = pred.view(-1, pred.size(2)) # (B*T) x C
- gold = gold.contiguous().view(-1) # (B*T)
- pred = pred.max(1)[1]
- non_pad_mask = gold.ne(PAD_TOKEN)
- num_correct = pred.eq(gold)
- num_correct = num_correct.masked_select(non_pad_mask).sum().item()
- return loss, num_correct
- def calculate_loss(pred, gold, input_lengths=None, target_lengths=None, smoothing=0.0, loss_type="ce"):
- """
- Calculate loss
- args:
- pred: B x T x C
- gold: B x T
- input_lengths: B (for CTC)
- target_lengths: B (for CTC)
- smoothing: label smoothing factor (currently unused)
- loss_type: ce|ctc (ctc requires PyTorch >= 1.0); only the "ce" branch is implemented here
- """
- pred = pred.view(-1, pred.size(2)) # (B*T) x C
- gold = gold.contiguous().view(-1) # (B*T)
- loss = F.cross_entropy(pred, gold, ignore_index = PAD_TOKEN, reduction="mean")
- return loss
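- # Note: despite the ce|ctc docstring, only the cross-entropy branch is implemented above.
- # A hedged sketch of what a CTC branch could look like (assumes pred is B x T x C logits and
- # that the padding/blank handling of the targets is adjusted accordingly; not part of the
- # original code):
- #   log_probs = F.log_softmax(pred, dim=-1).transpose(0, 1)  # T x B x C, as F.ctc_loss expects
- #   loss = F.ctc_loss(log_probs, gold, input_lengths, target_lengths,
- #                     blank=PAD_TOKEN, reduction="mean", zero_infinity=True)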
- def train(model, train_loader, train_sampler, valid_loader_list, start_epoch, num_epochs, label2id, id2label, last_metrics=None):
- """
- Training
- args:
- model: Model object
- train_loader: DataLoader object of the training set
- train_sampler: BucketingSampler used to build train_loader (not used directly in this loop)
- valid_loader_list: a list of validation DataLoader objects
- start_epoch: start epoch (> 0 if you resume the process)
- num_epochs: last epoch
- last_metrics: (if resume)
- """
- best_valid_loss = 1000000000 if last_metrics is None else last_metrics['valid_loss']
- optimizer = torch.optim.Adam([dict(params = model.parameters(), lr = 0.001)])
- for epoch in range(start_epoch, num_epochs):
- sys.stdout.flush()
- total_loss, total_cer, total_wer, total_char, total_word = 0, 0, 0, 0, 0
- model.train()
- model = model.cuda(0)
- for i, (data) in enumerate(train_loader, start = 0):
- src, tgt, src_percentages, src_lengths, tgt_lengths = data
- src = src.cuda()
- tgt = tgt.cuda()
- optimizer.zero_grad()
- # opt.zero_grad()
- pred, gold, hyp_seq, gold_seq = model(src, src_lengths, tgt, verbose=False)
- try: # handle case for CTC
- strs_gold, strs_hyps = [], []
- for ut_gold in gold_seq:
- str_gold = ""
- for x in ut_gold:
- if int(x) == PAD_TOKEN:
- break
- str_gold = str_gold + id2label[int(x)]
- strs_gold.append(str_gold)
- for ut_hyp in hyp_seq:
- str_hyp = ""
- for x in ut_hyp:
- if int(x) == PAD_TOKEN:
- break
- str_hyp = str_hyp + id2label[int(x)]
- strs_hyps.append(str_hyp)
- except Exception as e:
- print(e)
- # logging.info("NaN predictions")
- continue
- seq_length = pred.size(1)
- sizes = Variable(src_percentages.mul_(int(seq_length)).int(), requires_grad=False)
- loss, num_correct = calculate_metrics(
- pred, gold, input_lengths=sizes, target_lengths=tgt_lengths, smoothing=0, loss_type='ce')
- if loss.item() == float('Inf'):
- # logging.info("Found infinity loss, masking")
- loss = torch.where(loss != loss, torch.zeros_like(loss), loss) # NaN masking
- continue
- for j in range(len(strs_hyps)):
- strs_hyps[j] = strs_hyps[j].replace(SOS_CHAR, '').replace(EOS_CHAR, '')
- strs_gold[j] = strs_gold[j].replace(SOS_CHAR, '').replace(EOS_CHAR, '')
- cer = calculate_cer(strs_hyps[j].replace(' ', ''), strs_gold[j].replace(' ', ''))
- wer = calculate_wer(strs_hyps[j], strs_gold[j])
- total_cer += cer
- total_wer += wer
- total_char += len(strs_gold[j].replace(' ', ''))
- total_word += len(strs_gold[j].split(" "))
- loss.backward()
- optimizer.step()
- # opt.step()
- total_loss += loss.item()
- print("(Epoch {}) TRAIN LOSS:{:.4f} CER:{:.2f}% ".format( (epoch+1), total_loss/(i+1), total_cer*100/total_char ) )
- # evaluate
- print("")
- model.eval()
- for ind in range(len(valid_loader_list)):
- valid_loader = valid_loader_list[ind]
- total_valid_loss, total_valid_cer, total_valid_wer, total_valid_char, total_valid_word = 0, 0, 0, 0, 0
- # valid_pbar = tqdm(iter(valid_loader), leave=True, total=len(valid_loader))
- # for i, (data) in enumerate(valid_pbar):
- for i, (data) in enumerate(valid_loader):
- src, tgt, src_percentages, src_lengths, tgt_lengths = data
- src = src.cuda()
- tgt = tgt.cuda()
- with torch.no_grad():
- pred, gold, hyp_seq, gold_seq = model(src, src_lengths, tgt, verbose=False)
- seq_length = pred.size(1)
- sizes = Variable(src_percentages.mul_(int(seq_length)).int(), requires_grad=False)
- loss, num_correct = calculate_metrics(
- pred, gold, input_lengths=sizes, target_lengths=tgt_lengths, smoothing=0, loss_type='ce')
- if loss.item() == float('Inf'):
- # logging.info("Found infinity loss, masking")
- loss = torch.where(loss != loss, torch.zeros_like(loss), loss) # NaN masking
- continue
- try: # handle case for CTC
- strs_gold, strs_hyps = [], []
- for ut_gold in gold_seq:
- str_gold = ""
- for x in ut_gold:
- if int(x) == PAD_TOKEN:
- break
- str_gold = str_gold + id2label[int(x)]
- strs_gold.append(str_gold)
- for ut_hyp in hyp_seq:
- str_hyp = ""
- for x in ut_hyp:
- if int(x) == PAD_TOKEN:
- break
- str_hyp = str_hyp + id2label[int(x)]
- strs_hyps.append(str_hyp)
- except Exception as e:
- print(e)
- # logging.info("NaN predictions")
- continue
- for j in range(len(strs_hyps)):
- strs_hyps[j] = strs_hyps[j].replace(SOS_CHAR, '').replace(EOS_CHAR, '')
- strs_gold[j] = strs_gold[j].replace(SOS_CHAR, '').replace(EOS_CHAR, '')
- cer = calculate_cer(strs_hyps[j].replace(' ', ''), strs_gold[j].replace(' ', ''))
- wer = calculate_wer(strs_hyps[j], strs_gold[j])
- total_valid_cer += cer
- total_valid_wer += wer
- total_valid_char += len(strs_gold[j].replace(' ', ''))
- total_valid_word += len(strs_gold[j].split(" "))
- total_valid_loss += loss.item()
- print("VALID SET {} LOSS:{:.4f} CER:{:.2f}%".format(ind, total_valid_loss/(i+1), total_valid_cer*100/total_valid_char))
- if epoch % 5 == 0:
- # `metrics` here is the module-level variable set in __main__ (None); metric history is left as a TODO in save_model
- save_model(model, (epoch+1), metrics, label2id, id2label, best_model=False)
- # save the best model
- if best_valid_loss > total_valid_loss/len(valid_loader):
- best_valid_loss = total_valid_loss/len(valid_loader)
- save_model(model, (epoch+1), metrics, label2id, id2label, best_model=True)
- class Transformer(nn.Module):
- """
- Transformer class
- args:
- encoder: Encoder object
- decoder: Decoder object
- """
- def __init__(self, encoder, decoder, feat_extractor='vgg_cnn'):
- super(Transformer, self).__init__()
- self.encoder = encoder
- self.decoder = decoder
- self.id2label = decoder.id2label
- self.feat_extractor = feat_extractor
-
- for p in self.parameters():
- if p.dim() > 1:
- nn.init.xavier_uniform_(p)
- def forward(self, padded_input, input_lengths, padded_target, verbose=False):
- """
- args:
- padded_input: B x 1 x freq x T (when a CNN feature extractor is used)
- or B x T x D (when feat_extractor is '')
- input_lengths: B
- padded_target: B x T
- output:
- pred: B x T x vocab
- gold: B x T
- """
- if self.feat_extractor == 'emb_cnn' or self.feat_extractor == 'vgg_cnn':
- # NOTE: self.conv is not defined in this simplified version; keep feat_extractor='' (as in __main__) unless a CNN front-end is added
- padded_input = self.conv(padded_input)
- # Reshaping features
- sizes = padded_input.size() # B x H_1 (channel?) x H_2 x T
- padded_input = padded_input.view(sizes[0], sizes[1] * sizes[2], sizes[3])
- padded_input = padded_input.transpose(1, 2).contiguous() # BxTxH
- encoder_padded_outputs, _ = self.encoder(padded_input, input_lengths)
- pred, gold, *_ = self.decoder(padded_target, encoder_padded_outputs, input_lengths)
- hyp_best_scores, hyp_best_ids = torch.topk(pred, 1, dim=2)
- hyp_seq = hyp_best_ids.squeeze(2)
- gold_seq = gold
- return pred, gold, hyp_seq, gold_seq
- def evaluate(self, padded_input, input_lengths, padded_target):
- """
- args:
- padded_input: B x T x D
- input_lengths: B
- padded_target: B x T
- output:
- batch_ids_nbest_hyps: list of nbest id
- batch_strs_nbest_hyps: list of nbest str
- batch_strs_gold: list of gold str
- """
- if self.feat_extractor == 'emb_cnn' or self.feat_extractor == 'vgg_cnn':
- padded_input = self.conv(padded_input)
- # Reshaping features
- sizes = padded_input.size() # B x H_1 (channel?) x H_2 x T
- padded_input = padded_input.view(sizes[0], sizes[1] * sizes[2], sizes[3])
- padded_input = padded_input.transpose(1, 2).contiguous() # BxTxH
- encoder_padded_outputs, _ = self.encoder(padded_input, input_lengths)
- hyp, gold, *_ = self.decoder(padded_target, encoder_padded_outputs, input_lengths)
- hyp_best_scores, hyp_best_ids = torch.topk(hyp, 1, dim=2)
-
- strs_gold = ["".join([self.id2label[int(x)] for x in gold_seq]) for gold_seq in gold]
- strs_hyps = self.decoder.greedy_search(encoder_padded_outputs)
-
- return _, strs_hyps, strs_gold
-
- class Encoder(nn.Module):
- """
- Encoder Transformer class
- """
- def __init__(self, num_layers, num_heads, dim_model, dim_key, dim_value, dim_input, dim_inner, dropout=0.1, src_max_length=2500):
- super(Encoder, self).__init__()
- self.dim_input = dim_input
- self.num_layers = num_layers
- self.num_heads = num_heads
- self.dim_model = dim_model
- self.dim_key = dim_key
- self.dim_value = dim_value
- self.dim_inner = dim_inner
- self.src_max_length = src_max_length
- self.dropout = nn.Dropout(dropout)
- self.dropout_rate = dropout
- self.input_linear = nn.Linear(dim_input, dim_model)
- self.layer_norm_input = nn.LayerNorm(dim_model)
- self.positional_encoding = PositionalEncoding(dim_model, src_max_length)
- self.layers = nn.ModuleList([EncoderLayer(num_heads, dim_model, dim_inner, dim_key, dim_value, dropout=dropout) for _ in range(num_layers)])
- def forward(self, padded_input, input_lengths):
- """
- args:
- padded_input: B x T x D
- input_lengths: B
- return:
- output: B x T x H
- """
- encoder_self_attn_list = []
- # Prepare masks
- non_pad_mask = get_non_pad_mask(padded_input, input_lengths=input_lengths) # B x T x D
- seq_len = padded_input.size(1)
- self_attn_mask = get_attn_pad_mask(padded_input, input_lengths, seq_len) # B x T x T
- encoder_output = self.layer_norm_input(self.input_linear(
- padded_input)) + self.positional_encoding(padded_input)
- for layer in self.layers:
- encoder_output, self_attn = layer(
- encoder_output, non_pad_mask=non_pad_mask, self_attn_mask=self_attn_mask)
- encoder_self_attn_list += [self_attn]
- return encoder_output, encoder_self_attn_list
- class EncoderLayer(nn.Module):
- """
- Encoder Layer Transformer class
- """
- def __init__(self, num_heads, dim_model, dim_inner, dim_key, dim_value, dropout=0.1):
- super(EncoderLayer, self).__init__()
- self.self_attn = MultiHeadAttention(
- num_heads, dim_model, dim_key, dim_value, dropout=dropout)
- self.pos_ffn = PositionwiseFeedForwardWithConv(
- dim_model, dim_inner, dropout=dropout)
- def forward(self, enc_input, non_pad_mask=None, self_attn_mask=None):
- enc_output, self_attn = self.self_attn(
- enc_input, enc_input, enc_input, mask=self_attn_mask)
- enc_output *= non_pad_mask
- enc_output = self.pos_ffn(enc_output)
- enc_output *= non_pad_mask
- return enc_output, self_attn
- class Decoder(nn.Module):
- """
- Decoder Layer Transformer class
- """
- def __init__(self, id2label, num_src_vocab, num_trg_vocab, num_layers, num_heads, dim_emb, dim_model, dim_inner, dim_key, dim_value, dropout=0.1, trg_max_length=1000, emb_trg_sharing=False):
- super(Decoder, self).__init__()
- self.sos_id = SOS_TOKEN
- self.eos_id = EOS_TOKEN
- self.id2label = id2label
- self.num_src_vocab = num_src_vocab
- self.num_trg_vocab = num_trg_vocab
- self.num_layers = num_layers
- self.num_heads = num_heads
- self.dim_emb = dim_emb
- self.dim_model = dim_model
- self.dim_inner = dim_inner
- self.dim_key = dim_key
- self.dim_value = dim_value
- self.dropout_rate = dropout
- self.emb_trg_sharing = emb_trg_sharing
- self.trg_max_length = trg_max_length
- self.trg_embedding = nn.Embedding(num_trg_vocab, dim_emb, padding_idx = PAD_TOKEN)
- self.positional_encoding = PositionalEncoding(
- dim_model, max_length=trg_max_length)
- self.dropout = nn.Dropout(dropout)
- self.layers = nn.ModuleList([
- DecoderLayer(dim_model, dim_inner, num_heads,
- dim_key, dim_value, dropout=dropout)
- for _ in range(num_layers)
- ])
- self.output_linear = nn.Linear(dim_model, num_trg_vocab, bias=False)
- nn.init.xavier_normal_(self.output_linear.weight)
- if emb_trg_sharing:
- self.output_linear.weight = self.trg_embedding.weight
- self.x_logit_scale = (dim_model ** -0.5)
- else:
- self.x_logit_scale = 1.0
- def preprocess(self, padded_input):
- """
- Add SOS TOKEN and EOS TOKEN into padded_input
- """
- seq = [y[y != PAD_TOKEN] for y in padded_input]
- eos = seq[0].new([self.eos_id])
- sos = seq[0].new([self.sos_id])
- seq_in = [torch.cat([sos, y], dim=0) for y in seq]
- seq_out = [torch.cat([y, eos], dim=0) for y in seq]
- seq_in_pad = pad_list(seq_in, self.eos_id)
- seq_out_pad = pad_list(seq_out, PAD_TOKEN)
- assert seq_in_pad.size() == seq_out_pad.size()
- return seq_in_pad, seq_out_pad
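- # Example of what preprocess produces (teacher forcing), assuming a target with ids [a, b]:
- #   seq_in_pad  = [SOS, a, b, EOS, EOS, ...]  (decoder input, padded with EOS)
- #   seq_out_pad = [a, b, EOS, PAD, PAD, ...]  (gold output, padded with PAD)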
- def forward(self, padded_input, encoder_padded_outputs, encoder_input_lengths):
- """
- args:
- padded_input: B x T
- encoder_padded_outputs: B x T x H
- encoder_input_lengths: B
- returns:
- pred: B x T x vocab
- gold: B x T
- """
- decoder_self_attn_list, decoder_encoder_attn_list = [], []
- seq_in_pad, seq_out_pad = self.preprocess(padded_input)
- # Prepare masks
- non_pad_mask = get_non_pad_mask(seq_in_pad, pad_idx = EOS_TOKEN)
- self_attn_mask_subseq = get_subsequent_mask(seq_in_pad)
- self_attn_mask_keypad = get_attn_key_pad_mask(seq_k=seq_in_pad, seq_q=seq_in_pad, pad_idx = EOS_TOKEN)
- self_attn_mask = (self_attn_mask_keypad + self_attn_mask_subseq).gt(0)
- output_length = seq_in_pad.size(1)
- dec_enc_attn_mask = get_attn_pad_mask(
- encoder_padded_outputs, encoder_input_lengths, output_length)
- decoder_output = self.dropout(self.trg_embedding(
- seq_in_pad) * self.x_logit_scale + self.positional_encoding(seq_in_pad))
- for layer in self.layers:
- decoder_output, decoder_self_attn, decoder_enc_attn = layer(
- decoder_output, encoder_padded_outputs, non_pad_mask=non_pad_mask, self_attn_mask=self_attn_mask, dec_enc_attn_mask=dec_enc_attn_mask)
- decoder_self_attn_list += [decoder_self_attn]
- decoder_encoder_attn_list += [decoder_enc_attn]
- seq_logit = self.output_linear(decoder_output)
- pred, gold = seq_logit, seq_out_pad
- return pred, gold, decoder_self_attn_list, decoder_encoder_attn_list
- def post_process_hyp(self, hyp):
- """
- args:
- hyp: hypothesis dict with a 'yseq' list of token ids
- output:
- hypothesis string (leading SOS removed)
- """
- return "".join([self.id2label[int(x)] for x in hyp['yseq'][1:]])
- def greedy_search(self, encoder_padded_outputs, beam_width=2, lm_rescoring=False, lm=None, lm_weight=0.1, c_weight=1):
- """
- Greedy search, decode 1-best utterance
- args:
- encoder_padded_outputs: B x T x H
- output:
- sent: list of greedily decoded strings (size B)
- """
- # max_seq_len = self.trg_max_length
- ys = torch.ones(encoder_padded_outputs.size(0),1).fill_(SOS_TOKEN).long() # batch_size x 1
- if True:
- ys = ys.cuda()
- decoded_words = []
- for t in range(300):
- # for t in range(max_seq_len):
- # print(t)
- # Prepare masks
- non_pad_mask = torch.ones_like(ys).float().unsqueeze(-1) # batch_size x t x 1
- self_attn_mask = get_subsequent_mask(ys) # batch_size x t x t
- decoder_output = self.dropout(self.trg_embedding(ys) * self.x_logit_scale
- + self.positional_encoding(ys))
- for layer in self.layers:
- decoder_output, _, _ = layer(
- decoder_output, encoder_padded_outputs,
- non_pad_mask=non_pad_mask,
- self_attn_mask=self_attn_mask,
- dec_enc_attn_mask=None
- )
- prob = self.output_linear(decoder_output) # batch_size x t x label_size
- _, next_word = torch.max(prob[:, -1], dim=1)
- decoded_words.append([EOS_CHAR if ni.item() == EOS_TOKEN else self.id2label[ni.item()] for ni in next_word.view(-1)])
- next_word = next_word.unsqueeze(-1)
- if True:
- ys = torch.cat([ys, next_word.cuda()], dim=1)
- ys = ys.cuda()
- else:
- ys = torch.cat([ys, next_word], dim=1)
- sent = []
- for _, row in enumerate(np.transpose(decoded_words)):
- st = ''
- for e in row:
- if e == EOS_CHAR:
- break
- else:
- st += e
- sent.append(st)
- return sent
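- # Greedy decoding above: at every step the decoder re-attends over all tokens generated so far,
- # the argmax of the last position becomes the next token, and each utterance's string is cut at
- # the first EOS when the decoded words are assembled. The 300-step loop is a hard length cap
- # used by this simplified script.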
- class DecoderLayer(nn.Module):
- """
- Decoder Transformer class
- """
- def __init__(self, dim_model, dim_inner, num_heads, dim_key, dim_value, dropout=0.1):
- super(DecoderLayer, self).__init__()
- self.self_attn = MultiHeadAttention(
- num_heads, dim_model, dim_key, dim_value, dropout=dropout)
- self.encoder_attn = MultiHeadAttention(
- num_heads, dim_model, dim_key, dim_value, dropout=dropout)
- self.pos_ffn = PositionwiseFeedForwardWithConv(
- dim_model, dim_inner, dropout=dropout)
- def forward(self, decoder_input, encoder_output, non_pad_mask=None, self_attn_mask=None, dec_enc_attn_mask=None):
- decoder_output, decoder_self_attn = self.self_attn(
- decoder_input, decoder_input, decoder_input, mask=self_attn_mask)
- decoder_output *= non_pad_mask
- decoder_output, decoder_encoder_attn = self.encoder_attn(
- decoder_output, encoder_output, encoder_output, mask=dec_enc_attn_mask)
- decoder_output *= non_pad_mask
- decoder_output = self.pos_ffn(decoder_output)
- decoder_output *= non_pad_mask
- return decoder_output, decoder_self_attn, decoder_encoder_attn
- """
- General purpose functions
- """
- def pad_list(xs, pad_value):
- # From: espnet/src/nets/e2e_asr_th.py: pad_list()
- n_batch = len(xs)
- # max_len = max(x.size(0) for x in xs)
- # pad to a fixed length of 1000 (tgt_max_len) instead of the per-batch maximum
- max_len = 1000
- pad = xs[0].new(n_batch, max_len, * xs[0].size()[1:]).fill_(pad_value)
- for i in range(n_batch):
- pad[i, :xs[i].size(0)] = xs[i]
- return pad
- """
- Transformer common layers
- """
- def get_non_pad_mask(padded_input, input_lengths=None, pad_idx=None):
- """
- padding position is set to 0, either use input_lengths or pad_idx
- """
- assert input_lengths is not None or pad_idx is not None
- if input_lengths is not None:
- # padded_input: N x T x ..
- N = padded_input.size(0)
- non_pad_mask = padded_input.new_ones(padded_input.size()[:-1]) # B x T
- for i in range(N):
- non_pad_mask[i, input_lengths[i]:] = 0
- if pad_idx is not None:
- # padded_input: N x T
- assert padded_input.dim() == 2
- non_pad_mask = padded_input.ne(pad_idx).float()
- # unsqueeze(-1) for broadcast
- return non_pad_mask.unsqueeze(-1)
- def get_attn_key_pad_mask(seq_k, seq_q, pad_idx):
- """
- For masking out the padding part of key sequence.
- """
- # Expand to fit the shape of key query attention matrix.
- len_q = seq_q.size(1)
- padding_mask = seq_k.eq(pad_idx)
- padding_mask = padding_mask.unsqueeze(1).expand(-1, len_q, -1) # B x T_Q x T_K
- return padding_mask
- def get_attn_pad_mask(padded_input, input_lengths, expand_length):
- """mask position is set to 1"""
- # N x Ti x 1
- non_pad_mask = get_non_pad_mask(padded_input, input_lengths=input_lengths)
- # N x Ti, lt(1) like not operation
- pad_mask = non_pad_mask.squeeze(-1).lt(1)
- attn_mask = pad_mask.unsqueeze(1).expand(-1, expand_length, -1)
- return attn_mask
- def get_subsequent_mask(seq):
- ''' For masking out the subsequent info. '''
- sz_b, len_s = seq.size()
- subsequent_mask = torch.triu(
- torch.ones((len_s, len_s), device=seq.device, dtype=torch.uint8), diagonal=1)
- subsequent_mask = subsequent_mask.unsqueeze(0).expand(sz_b, -1, -1) # b x ls x ls
- return subsequent_mask
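- # For a length-3 sequence, get_subsequent_mask yields (per batch element):
- #   [[0, 1, 1],
- #    [0, 0, 1],
- #    [0, 0, 0]]
- # i.e. position t may only attend to positions <= t (1 marks a masked position).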
- class PositionalEncoding(nn.Module):
- """
- Positional Encoding class
- """
- def __init__(self, dim_model, max_length=2000):
- super(PositionalEncoding, self).__init__()
- pe = torch.zeros(max_length, dim_model, requires_grad=False)
- position = torch.arange(0, max_length).unsqueeze(1).float()
- exp_term = torch.exp(torch.arange(0, dim_model, 2).float() * -(math.log(10000.0) / dim_model))
- pe[:, 0::2] = torch.sin(position * exp_term) # even indices (0, 2, 4, ...)
- pe[:, 1::2] = torch.cos(position * exp_term) # odd indices (1, 3, 5, ...)
- pe = pe.unsqueeze(0)
- self.register_buffer('pe', pe)
- def forward(self, input):
- """
- args:
- input: B x T x D
- output:
- tensor: B x T
- """
- return self.pe[:, :input.size(1)]
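- # The buffer above implements the standard sinusoidal encoding:
- #   PE(pos, 2i)   = sin(pos / 10000^(2i / dim_model))
- #   PE(pos, 2i+1) = cos(pos / 10000^(2i / dim_model))
- # and forward() simply returns the first T positions, to be broadcast over the batch.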
- class PositionwiseFeedForward(nn.Module):
- """
- Position-wise Feedforward Layer class
- FFN(x) = max(0, x W1 + b1) W2 + b2
- """
- def __init__(self, dim_model, dim_ff, dropout=0.1):
- super(PositionwiseFeedForward, self).__init__()
- self.linear_1 = nn.Linear(dim_model, dim_ff)
- self.linear_2 = nn.Linear(dim_ff, dim_model)
- self.dropout = nn.Dropout(dropout)
- self.layer_norm = nn.LayerNorm(dim_model)
- def forward(self, x):
- """
- args:
- x: tensor
- output:
- y: tensor
- """
- residual = x
- output = self.dropout(self.linear_2(F.relu(self.linear_1(x))))
- output = self.layer_norm(output + residual)
- return output
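- # Note: PositionwiseFeedForward is kept for reference only; the EncoderLayer and DecoderLayer
- # classes above actually use PositionwiseFeedForwardWithConv (defined next), which replaces the
- # two linear layers with 1x1 convolutions over the time axis.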
- class PositionwiseFeedForwardWithConv(nn.Module):
- """
- Position-wise Feedforward Layer Implementation with Convolution class
- """
- def __init__(self, dim_model, dim_hidden, dropout=0.1):
- super(PositionwiseFeedForwardWithConv, self).__init__()
- self.conv_1 = nn.Conv1d(dim_model, dim_hidden, 1)
- self.conv_2 = nn.Conv1d(dim_hidden, dim_model, 1)
- self.dropout = nn.Dropout(dropout)
- self.layer_norm = nn.LayerNorm(dim_model)
- def forward(self, x):
- residual = x
- output = x.transpose(1, 2)
- output = self.conv_2(F.relu(self.conv_1(output)))
- output = output.transpose(1, 2)
- output = self.dropout(output)
- output = self.layer_norm(output + residual)
- return output
- class MultiHeadAttention(nn.Module):
- def __init__(self, num_heads, dim_model, dim_key, dim_value, dropout=0.1):
- super(MultiHeadAttention, self).__init__()
- self.num_heads = num_heads
- self.dim_model = dim_model
- self.dim_key = dim_key
- self.dim_value = dim_value
- self.query_linear = nn.Linear(dim_model, num_heads * dim_key)
- self.key_linear = nn.Linear(dim_model, num_heads * dim_key)
- self.value_linear = nn.Linear(dim_model, num_heads * dim_value)
- nn.init.normal_(self.query_linear.weight, mean=0, std=np.sqrt(2.0 / (self.dim_model + self.dim_key)))
- nn.init.normal_(self.key_linear.weight, mean=0, std=np.sqrt(2.0 / (self.dim_model + self.dim_key)))
- nn.init.normal_(self.value_linear.weight, mean=0, std=np.sqrt(2.0 / (self.dim_model + self.dim_value)))
- self.attention = ScaledDotProductAttention(temperature=np.power(dim_key, 0.5), attn_dropout=dropout)
- self.layer_norm = nn.LayerNorm(dim_model)
- self.output_linear = nn.Linear(num_heads * dim_value, dim_model)
- nn.init.xavier_normal_(self.output_linear.weight)
-
- self.dropout = nn.Dropout(dropout)
- def forward(self, query, key, value, mask=None):
- """
- query: B x T_Q x H, key: B x T_K x H, value: B x T_V x H
- mask: B x T x T (attention mask)
- """
- batch_size, len_query, _ = query.size()
- batch_size, len_key, _ = key.size()
- batch_size, len_value, _ = value.size()
- residual = query
- query = self.query_linear(query).view(batch_size, len_query, self.num_heads, self.dim_key) # B x T_Q x num_heads x H_K
- key = self.key_linear(key).view(batch_size, len_key, self.num_heads, self.dim_key) # B x T_K x num_heads x H_K
- value = self.value_linear(value).view(batch_size, len_value, self.num_heads, self.dim_value) # B x T_V x num_heads x H_V
- query = query.permute(2, 0, 1, 3).contiguous().view(-1, len_query, self.dim_key) # (num_heads * B) x T_Q x H_K
- key = key.permute(2, 0, 1, 3).contiguous().view(-1, len_key, self.dim_key) # (num_heads * B) x T_K x H_K
- value = value.permute(2, 0, 1, 3).contiguous().view(-1, len_value, self.dim_value) # (num_heads * B) x T_V x H_V
- if mask is not None:
- mask = mask.repeat(self.num_heads, 1, 1) # (B * num_head) x T x T
-
- output, attn = self.attention(query, key, value, mask=mask)
- output = output.view(self.num_heads, batch_size, len_query, self.dim_value) # num_heads x B x T_Q x H_V
- output = output.permute(1, 2, 0, 3).contiguous().view(batch_size, len_query, -1) # B x T_Q x (num_heads * H_V)
- output = self.dropout(self.output_linear(output)) # B x T_Q x H_O
- output = self.layer_norm(output + residual)
- return output, attn
- class ScaledDotProductAttention(nn.Module):
- ''' Scaled Dot-Product Attention '''
- def __init__(self, temperature, attn_dropout=0.1):
- super().__init__()
- self.temperature = temperature
- self.dropout = nn.Dropout(attn_dropout)
- self.softmax = nn.Softmax(dim=2)
- def forward(self, q, k, v, mask=None):
- """
- """
- attn = torch.bmm(q, k.transpose(1, 2))
- attn = attn / self.temperature
- if mask is not None:
- attn = attn.masked_fill(mask, -np.inf)
- attn = self.softmax(attn)
- attn = self.dropout(attn)
- output = torch.bmm(attn, v)
- return output, attn
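- # ScaledDotProductAttention computes softmax(Q K^T / temperature) V, with
- # temperature = sqrt(dim_key) as set by MultiHeadAttention; masked positions are
- # filled with -inf before the softmax so they receive zero attention weight.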
- def save_model(model, epoch, metrics, label2id, id2label, best_model=False):
- """
- Saving model, TODO adding history
- """
- save_folder = 'models/'
- name = 'model'
- if best_model:
- save_path = "{}/{}/best_model.th".format(save_folder, name)
- else:
- save_path = "{}/{}/epoch_{}.th".format(save_folder, name, epoch)
- if not os.path.exists(save_folder + "/" + name):
- os.makedirs(save_folder + "/" + name)
- print("SAVE MODEL to", save_path)
- args = {
- 'label2id': label2id,
- 'id2label': id2label,
- 'epoch': epoch,
- 'model_state_dict': model.state_dict(),
- 'metrics': metrics
- }
- torch.save(args, save_path)
- def init_transformer_model2(num_layers, num_heads, dim_model, dim_key, dim_value, dim_input, dim_inner, dim_emb, src_max_len, tgt_max_len, dropout, emb_trg_sharing, feat_extractor, label2id, id2label):
- """
- Initiate a new transformer object
- """
- encoder = Encoder(num_layers, num_heads=num_heads, dim_model=dim_model, dim_key=dim_key,
- dim_value=dim_value, dim_input=dim_input, dim_inner=dim_inner, src_max_length=src_max_len, dropout=dropout)
- decoder = Decoder(id2label, num_src_vocab=len(label2id), num_trg_vocab=len(label2id), num_layers=num_layers, num_heads=num_heads,
- dim_emb=dim_emb, dim_model=dim_model, dim_inner=dim_inner, dim_key=dim_key, dim_value=dim_value, trg_max_length=tgt_max_len, dropout=dropout, emb_trg_sharing=emb_trg_sharing)
- model = Transformer(encoder, decoder, feat_extractor=feat_extractor)
- return model
- def load_audio(path):
- # sound, _ = torchaudio.load(path, normalization=True)
- sound, _ = torchaudio.load(path)
- sound = sound.numpy().T
- if len(sound.shape) > 1:
- if sound.shape[1] == 1:
- sound = sound.squeeze()
- else:
- sound = sound.mean(axis=1) # multiple channels, average
- return sound
- def get_audio_length(path):
- output = subprocess.check_output(
- ['soxi -D "%s"' % path.strip()], shell=True)
- return float(output)
- def audio_with_sox(path, sample_rate, start_time, end_time):
- """
- crop and resample the recording with sox and loads it.
- """
- with NamedTemporaryFile(suffix=".wav") as tar_file:
- tar_filename = tar_file.name
- sox_params = "sox "{}" -r {} -c 1 -b 16 -e si {} trim {} ={} >/dev/null 2>&1".format(path, sample_rate, tar_filename, start_time, end_time)
- os.system(sox_params)
- y = load_audio(tar_filename)
- return y
- def augment_audio_with_sox(path, sample_rate, tempo, gain):
- """
- Changes tempo and gain of the recording with sox and loads it.
- """
- with NamedTemporaryFile(suffix=".wav") as augmented_file:
- augmented_filename = augmented_file.name
- sox_augment_params = ["tempo", "{:.3f}".format(
- tempo), "gain", "{:.3f}".format(gain)]
- sox_params = "sox "{}" -r {} -c 1 -b 16 -e si {} {} >/dev/null 2>&1".format(
- path, sample_rate, augmented_filename, " ".join(sox_augment_params))
- os.system(sox_params)
- y = load_audio(augmented_filename)
- return y
- def load_randomly_augmented_audio(path, sample_rate=16000, tempo_range=(0.85, 1.15), gain_range=(-6, 8)):
- """
- Picks tempo and gain uniformly, applies it to the utterance by using sox utility.
- Returns the augmented utterance.
- """
- low_tempo, high_tempo = tempo_range
- tempo_value = np.random.uniform(low=low_tempo, high=high_tempo)
- low_gain, high_gain = gain_range
- gain_value = np.random.uniform(low=low_gain, high=high_gain)
- audio = augment_audio_with_sox(path=path, sample_rate=sample_rate,
- tempo=tempo_value, gain=gain_value)
- return audio
- class AudioParser(object):
- def parse_transcript(self, transcript_path):
- """
- :param transcript_path: Path where transcript is stored from the manifest file
- :return: Transcript in training/testing format
- """
- raise NotImplementedError
- def parse_audio(self, audio_path):
- """
- :param audio_path: Path where audio is stored from the manifest file
- :return: Audio in training/testing format
- """
- raise NotImplementedError
- class SpectrogramParser(AudioParser):
- def __init__(self, audio_conf, normalize=False, augment=False):
- """
- Parses audio file into spectrogram with optional normalization and various augmentations
- :param audio_conf: Dictionary containing the sample rate, window and the window length/stride in seconds
- :param normalize(default False): Apply standard mean and deviation normalization to audio tensor
- :param augment(default False): Apply random tempo and gain perturbations
- """
- super(SpectrogramParser, self).__init__()
- self.window_stride = audio_conf['window_stride']
- self.window_size = audio_conf['window_size']
- self.sample_rate = audio_conf['sample_rate']
- self.window = windows.get(audio_conf['window'], windows['hamming'])
- self.normalize = normalize
- self.augment = augment
- self.noiseInjector = NoiseInjection(audio_conf['noise_dir'], self.sample_rate,
- audio_conf['noise_levels']) if audio_conf.get(
- 'noise_dir') is not None else None
- self.noise_prob = audio_conf.get('noise_prob')
- def parse_audio(self, audio_path):
- if self.augment:
- y = load_randomly_augmented_audio(audio_path, self.sample_rate)
- else:
- y = load_audio(audio_path)
- if self.noiseInjector:
- # logging.info("inject noise")
- add_noise = np.random.binomial(1, self.noise_prob)
- if add_noise:
- y = self.noiseInjector.inject_noise(y)
- n_fft = int(self.sample_rate * self.window_size)
- win_length = n_fft
- hop_length = int(self.sample_rate * self.window_stride)
- # Short-time Fourier transform (STFT)
- D = librosa.stft(y, n_fft=n_fft, hop_length=hop_length,
- win_length=win_length, window=self.window)
- spect, phase = librosa.magphase(D)
- # S = log(S+1)
- spect = np.log1p(spect)
- spect = torch.FloatTensor(spect)
- if self.normalize:
- mean = spect.mean()
- std = spect.std()
- spect.add_(-mean)
- spect.div_(std)
- return spect
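- # With the settings used in __main__ (16 kHz audio, 20 ms window, 10 ms stride),
- # n_fft = 320 and hop_length = 160, so each spectrogram has 161 frequency bins,
- # which matches dim_input = 161 below.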
- def parse_transcript(self, transcript_path):
- raise NotImplementedError
- class SpectrogramDataset(Dataset, SpectrogramParser):
- def __init__(self, audio_conf, manifest_filepath_list, label2id, normalize=False, augment=False):
- """
- Dataset that loads tensors via a csv containing file paths to audio files and transcripts separated by
- a comma. Each new line is a different sample. Example below:
- /path/to/audio.wav,/path/to/audio.txt
- ...
- :param audio_conf: Dictionary containing the sample rate, window and the window length/stride in seconds
- :param manifest_filepath_list: List of paths to manifest csv files as described above
- :param label2id: Dictionary mapping each character to an integer id
- :param normalize: Apply standard mean and deviation normalization to audio tensor
- :param augment(default False): Apply random tempo and gain perturbations
- """
- self.max_size = 0
- self.ids_list = []
- for i in range(len(manifest_filepath_list)):
- manifest_filepath = manifest_filepath_list[i]
- with open(manifest_filepath) as f:
- ids = f.readlines()
- ids = [x.strip().split(',') for x in ids]
- self.ids_list.append(ids)
- self.max_size = max(len(ids), self.max_size)
- self.manifest_filepath_list = manifest_filepath_list
- self.label2id = label2id
- super(SpectrogramDataset, self).__init__(
- audio_conf, normalize, augment)
- def __getitem__(self, index):
- random_id = random.randint(0, len(self.ids_list)-1)
- ids = self.ids_list[random_id]
- sample = ids[index % len(ids)]
- audio_path, transcript_path = sample[0], sample[1]
- src_max_len = 4000
- spect = self.parse_audio(audio_path)[:,:src_max_len]
- transcript = self.parse_transcript(transcript_path)
- return spect, transcript
- def parse_transcript(self, transcript_path):
- with open(transcript_path, 'r', encoding='utf8') as transcript_file:
- transcript = SOS_CHAR + transcript_file.read().replace('\n', '').lower() + EOS_CHAR
- transcript = list(
- filter(None, [self.label2id.get(x) for x in list(transcript)]))
- return transcript
- def __len__(self):
- return self.max_size
- class NoiseInjection(object):
- def __init__(self,
- path=None,
- sample_rate=16000,
- noise_levels=(0, 0.5)):
- """
- Adds noise to an input signal at a given noise level; the higher the noise level, the more noise is added.
- Modified code from https://github.com/willfrey/audio/blob/master/torchaudio/transforms.py
- """
- if not os.path.exists(path):
- print("Directory doesn't exist: {}".format(path))
- raise IOError
- self.paths = path is not None and librosa.util.find_files(path)
- self.sample_rate = sample_rate
- self.noise_levels = noise_levels
- def inject_noise(self, data):
- noise_path = np.random.choice(self.paths)
- noise_level = np.random.uniform(*self.noise_levels)
- return self.inject_noise_sample(data, noise_path, noise_level)
- def inject_noise_sample(self, data, noise_path, noise_level):
- noise_len = get_audio_length(noise_path)
- data_len = len(data) / self.sample_rate
- noise_start = np.random.rand() * (noise_len - data_len)
- noise_end = noise_start + data_len
- noise_dst = audio_with_sox(
- noise_path, self.sample_rate, noise_start, noise_end)
- assert len(data) == len(noise_dst)
- noise_energy = np.sqrt(noise_dst.dot(noise_dst) / noise_dst.size)
- data_energy = np.sqrt(data.dot(data) / data.size)
- data += noise_level * noise_dst * data_energy / noise_energy
- return data
- def _collate_fn(batch):
- def func(p):
- return p[0].size(1)
- def func_tgt(p):
- return len(p[1])
- # descending sorted
- batch = sorted(batch, key=lambda sample: sample[0].size(1), reverse=True)
- max_seq_len = max(batch, key=func)[0].size(1)
- freq_size = max(batch, key=func)[0].size(0)
- max_tgt_len = len(max(batch, key=func_tgt)[1])
- inputs = torch.zeros(len(batch), 1, freq_size, max_seq_len)
- input_sizes = torch.IntTensor(len(batch))
- input_percentages = torch.FloatTensor(len(batch))
- targets = torch.zeros(len(batch), max_tgt_len).long()
- target_sizes = torch.IntTensor(len(batch))
- for x in range(len(batch)):
- sample = batch[x]
- input_data = sample[0]
- target = sample[1]
- seq_length = input_data.size(1)
- input_sizes[x] = seq_length
- inputs[x][0].narrow(1, 0, seq_length).copy_(input_data)
- input_percentages[x] = seq_length / float(max_seq_len)
- target_sizes[x] = len(target)
- targets[x][:len(target)] = torch.IntTensor(target)
- return inputs, targets, input_percentages, input_sizes, target_sizes
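- # _collate_fn returns, per batch of size B:
- #   inputs:             B x 1 x freq x T_max spectrograms, zero-padded in time
- #   targets:            B x L_max label ids, padded with 0 (PAD_TOKEN)
- #   input_percentages:  fraction of T_max actually used by each utterance
- #   input_sizes / target_sizes: true lengths before padding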
- class AudioDataLoader(DataLoader):
- def __init__(self, *args, **kwargs):
- super(AudioDataLoader, self).__init__(*args, **kwargs)
- self.collate_fn = _collate_fn
- class BucketingSampler(Sampler):
- def __init__(self, data_source, batch_size=1):
- """
- Samples batches assuming they are in order of size to batch similarly sized samples together.
- """
- super(BucketingSampler, self).__init__(data_source)
- self.data_source = data_source
- ids = list(range(0, len(data_source)))
- self.bins = [ids[i:i + batch_size]
- for i in range(0, len(ids), batch_size)]
- def __iter__(self):
- for ids in self.bins:
- np.random.shuffle(ids)
- yield ids
- def __len__(self):
- return len(self.bins)
- def shuffle(self, epoch):
- np.random.shuffle(self.bins)
- if __name__ == '__main__':
- batch_size=8
- dim_emb=512
- dim_inner=1024
- dim_input=161
- dim_key=64
- dim_model=512
- dim_value=64
- dropout=0.1
- emb_trg_sharing=False
- epochs=300
- # feat_extractor='vgg_cnn'
- feat_extractor=''
- labels_path=r'D:\2-LearningCode\902-ASR\AISHELL-1\data_aishell\LabelWav.json'
- noise_dir=None
- noise_max=0.5
- noise_min=0.0
- noise_prob=0.4
- num_heads=5
- num_layers=3
- num_workers=0
- prob_weight=1.0
- sample_rate=16000
- shuffle=False
- src_max_len=4000
- test_manifest_list=[r'D:\2-LearningCode\902-ASR\AISHELL-1\data_aishell\ASRInfo_test.txt']
- tgt_max_len=1000
- train_manifest_list=[r'D:\2-LearningCode\902-ASR\AISHELL-1\data_aishell\ASRInfo_train.txt']
- valid_manifest_list=[r'D:\2-LearningCode\902-ASR\AISHELL-1\data_aishell\ASRInfo_val.txt']
- window='hamming'
- window_size=0.02
- window_stride=0.01
-
- print("="*50)
- if not os.path.exists("./log"):
- os.mkdir("./log")
- audio_conf = dict(sample_rate = sample_rate,
- window_size = window_size,
- window_stride = window_stride,
- window = window,
- noise_dir = noise_dir,
- noise_prob = noise_prob,
- noise_levels = (noise_min, noise_max))
- with open(labels_path, 'r', encoding="utf-8") as label_file:
- labels = str(''.join(json.load(label_file)))
- # add PAD_CHAR, SOS_CHAR, EOS_CHAR
- labels = PAD_CHAR + SOS_CHAR + EOS_CHAR + labels
- label2id, id2label = {}, {}
- count = 0
- for i in range(len(labels)):
- if labels[i] not in label2id:
- label2id[labels[i]] = count
- id2label[count] = labels[i]
- count += 1
- else:
- print("multiple label: ", labels[i])
- train_data = SpectrogramDataset(audio_conf, manifest_filepath_list= train_manifest_list, label2id=label2id, normalize=True, augment = False)
- train_sampler = BucketingSampler(train_data, batch_size = batch_size)
- train_loader = AudioDataLoader(train_data, num_workers = num_workers, batch_sampler = train_sampler)
- valid_loader_list, test_loader_list = [], []
- for i in range(len(valid_manifest_list)):
- valid_data = SpectrogramDataset(audio_conf, manifest_filepath_list=[ valid_manifest_list[i]], label2id=label2id, normalize=True, augment=False)
- valid_loader = AudioDataLoader(valid_data, num_workers = num_workers, batch_size= batch_size)
- valid_loader_list.append(valid_loader)
- start_epoch = 0
- metrics = None
- model = init_transformer_model2(num_layers, num_heads, dim_model, dim_key, dim_value, dim_input, dim_inner, dim_emb,
- src_max_len, tgt_max_len, dropout, emb_trg_sharing, feat_extractor, label2id, id2label)
- model = model.cuda(0)
- train(model, train_loader, train_sampler, valid_loader_list, start_epoch, epochs, label2id, id2label, metrics)