Tools¶
create_ontology
¶
__version__
special
¶
custom_style_2
¶
Database
¶
Source code in adviser/tools/create_ontology.py
class Database(object):
    def __init__(self, path):
        conn = sqlite3.connect(path)
        cursor = conn.cursor()
        self.tables = {}
        # result will be (type, name, tbl_name, rootpage, sql)
        cursor.execute("SELECT * FROM sqlite_master where type='table'")
        for _, _, table, _, _ in cursor.fetchall():
            self.tables[table] = DatabaseTable(table)
        for table in self.tables.keys():
            # get fields/slots
            # result will be (id, name, type, not null, default, primary key)
            cursor.execute(f"PRAGMA table_info({table});")
            self.tables[table].fields = cursor.fetchall()
            # make sure that fields are sorted according to field index (should be already anyway)
            self.tables[table].fields = sorted(self.tables[table].fields, key=lambda field: field[0])
            # get entries (especially for possible values)
            cursor.execute(f"SELECT * FROM {table}")
            self.tables[table].entries = cursor.fetchall()
        # add user and system actions

    def get_tables(self):
        return list(self.tables.keys())

    def get_slots(self, table):
        return self.tables[table].get_slots()

    def get_slot_values(self, table, slot):
        return self.tables[table].get_slot_values(slot)
__init__(self, path)
special
¶
Source code in adviser/tools/create_ontology.py
def __init__(self, path):
    conn = sqlite3.connect(path)
    cursor = conn.cursor()
    self.tables = {}
    # result will be (type, name, tbl_name, rootpage, sql)
    cursor.execute("SELECT * FROM sqlite_master where type='table'")
    for _, _, table, _, _ in cursor.fetchall():
        self.tables[table] = DatabaseTable(table)
    for table in self.tables.keys():
        # get fields/slots
        # result will be (id, name, type, not null, default, primary key)
        cursor.execute(f"PRAGMA table_info({table});")
        self.tables[table].fields = cursor.fetchall()
        # make sure that fields are sorted according to field index (should be already anyway)
        self.tables[table].fields = sorted(self.tables[table].fields, key=lambda field: field[0])
        # get entries (especially for possible values)
        cursor.execute(f"SELECT * FROM {table}")
        self.tables[table].entries = cursor.fetchall()
    # add user and system actions
get_slot_values(self, table, slot)
¶
get_slots(self, table)
¶
get_tables(self)
¶
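For orientation, here is a minimal usage sketch of the `Database` helper. The import path and the database file name are assumptions for illustration and presume an existing SQLite database with at least one table; they are not taken from the source.

```python
# Hedged usage sketch: the module path and 'restaurants.db' are placeholders.
from tools.create_ontology import Database  # assumed import path

db = Database('restaurants.db')
tables = db.get_tables()            # e.g. ['restaurants']
slots = db.get_slots(tables[0])     # column names act as slots
print(db.get_slot_values(tables[0], slots[0]))
```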
DatabaseTable
¶
Source code in adviser/tools/create_ontology.py
class DatabaseTable(object):
    def __init__(self, name):
        self.name = name
        self.fields = []
        self.entries = []

    def _get_slot_id(self, slot):
        for field in self.fields:
            if field[1] == slot:
                return field[0]
        return -1

    def get_slots(self):
        return [field[1] for field in self.fields]

    def get_slot_values(self, slot, dontcare=False):
        # get slot id
        id = self._get_slot_id(slot)
        assert id >= 0, f"Slot '{slot}' is not part of the database table '{self.name}'"
        values = sorted(list(set([entry[id] for entry in self.entries])))
        if dontcare and not ('dontcare' in values or "do n't care" in values):
            values.append('dontcare')
        return values
__init__(self, name)
special
¶
get_slot_values(self, slot, dontcare=False)
¶
Source code in adviser/tools/create_ontology.py
def get_slot_values(self, slot, dontcare=False):
    # get slot id
    id = self._get_slot_id(slot)
    assert id >= 0, f"Slot '{slot}' is not part of the database table '{self.name}'"
    values = sorted(list(set([entry[id] for entry in self.entries])))
    if dontcare and not ('dontcare' in values or "do n't care" in values):
        values.append('dontcare')
    return values
get_slots(self)
¶
get_defaults()
¶
run_questions(db)
¶
Source code in adviser/tools/create_ontology.py
def run_questions(db: Database):
    # initialize with default values
    answers = get_defaults()
    questions = [
        {
            'type': 'list',
            'qmark': '>>>',
            'message': 'Select table to create ontology for',
            'name': 'table',
            'choices': [{'key': str(id), 'name': table, 'value': table} for id, table in enumerate(db.get_tables())],
            'validate': lambda answer: 'You must choose at least one table.' \
                if len(answer) == 0 else True
        },
        {
            'type': 'input',
            'qmark': '>>>',
            'message': 'Enter the name of the domain:',
            'name': 'domain',
            'default': lambda answers: answers['table']
        },
        {
            'type': 'list',
            'qmark': '>>>',
            'name': 'key',
            'message': 'Which slot will be used as key? (The key uniquely identifies an entity in the database, e.g. the name in case of restaurants)',
            'choices': lambda answers: [{'name': slot} for slot in db.get_slots(answers['table'])]
        },
        {
            'type': 'checkbox',
            'qmark': '>>>',
            'name': 'requestable',
            'message': 'Select user requestables',
            'choices': lambda answers: [{'name': slot, 'checked': slot != 'id'} for slot in db.get_slots(answers['table'])]
        },
        {
            'type': 'checkbox',
            'qmark': '>>>',
            'name': 'system_requestable',
            'message': 'Select system requestables',
            'choices': lambda answers: [{'name': slot} for slot in db.get_slots(answers['table'])]
        },
        {
            'type': 'checkbox',
            'qmark': '>>>',
            'name': 'informable',
            'message': 'Select informable slots',
            'choices': lambda answers: [{'name': slot} for slot in db.get_slots(answers['table'])]
        }]
    answers_ = prompt(questions, style=custom_style_2)
    # check whether there are answers (e.g. if the user cancels the prompt using Ctrl+c)
    if not answers_:
        exit()
    answers.update(answers_)
    # get values for informable slots
    questions = [
        {
            'type': 'checkbox',
            'qmark': '>>>',
            'name': slot,
            'message': f'Select values for informable slot {slot}',
            'choices': [{'name': value, 'checked': value != 'dontcare'} for value in db.get_slot_values(answers['table'], slot)]
        } for slot in answers['informable']
    ]
    values = prompt(questions, style=custom_style_2)
    # merge informable slot values with informable slots
    answers['informable'] = {slot: values[slot] for slot in answers['informable'] if slot in values}
    # get binary slots
    questions = [
        {
            'type': 'checkbox',
            'qmark': '>>>',
            'name': 'binary',
            'message': 'Select binary slots',
            'choices': [{'name': slot, 'checked': set(db.get_slot_values(answers['table'], slot)) == {'true', 'false'}} for slot in list(answers['informable'].keys()) + answers['requestable'] + answers['system_requestable']]
        }
    ]
    answers_ = prompt(questions, style=custom_style_2)
    # check whether there are answers (e.g. if the user cancels the prompt using Ctrl+c)
    if not answers_:
        exit()
    answers.update(answers_)
    return answers
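A hedged sketch of how `run_questions` might be driven end to end. The output file name and the final `json.dump` step are assumptions; the actual script may post-process the answers differently.

```python
# Hedged sketch: drive the interactive prompts and persist the answers.
# 'restaurants.db' and 'ontology.json' are placeholder names.
import json

from tools.create_ontology import Database, run_questions  # assumed import path

db = Database('restaurants.db')
answers = run_questions(db)          # interactive prompts in the terminal
with open('ontology.json', 'w') as f:
    json.dump(answers, f, indent=2)  # assumption: answers are JSON-serializable
```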
espnet_minimal
special
¶
asr
special
¶
asr_utils
¶
add_gradient_noise(model, iteration, duration=100, eta=1.0, scale_factor=0.55)
¶
Adds noise from a standard normal distribution to the gradients. The standard deviation (`sigma`) is controlled by the three hyper-parameters below. `sigma` goes to zero (no noise) with more iterations.

Parameters:

Name | Type | Description | Default
---|---|---|---
model | torch.nn.model | Model. | required
iteration | int | Number of iterations. | required
duration | int {100, 1000} | Number of durations to control the interval of the `sigma` change. | 100
eta | float {0.01, 0.3, 1.0} | The magnitude of `sigma`. | 1.0
scale_factor | float {0.55} | The scale of `sigma`. | 0.55
Source code in adviser/tools/espnet_minimal/asr/asr_utils.py
def add_gradient_noise(model, iteration, duration=100, eta=1.0, scale_factor=0.55):
    """Adds noise from a standard normal distribution to the gradients.
    The standard deviation (`sigma`) is controlled by the three hyper-parameters below.
    `sigma` goes to zero (no noise) with more iterations.
    Args:
        model (torch.nn.model): Model.
        iteration (int): Number of iterations.
        duration (int) {100, 1000}: Number of durations to control the interval of the `sigma` change.
        eta (float) {0.01, 0.3, 1.0}: The magnitude of `sigma`.
        scale_factor (float) {0.55}: The scale of `sigma`.
    """
    interval = (iteration // duration) + 1
    sigma = eta / interval ** scale_factor
    for param in model.parameters():
        if param.grad is not None:
            _shape = param.grad.size()
            noise = sigma * torch.randn(_shape).to(param.device)
            param.grad += noise
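A minimal sketch of calling `add_gradient_noise` inside a training step. The toy model, optimizer, and data are placeholders, not part of the source; only the import path is assumed.

```python
# Hedged sketch: gradient noise injection in a toy training loop.
import torch

from tools.espnet_minimal.asr.asr_utils import add_gradient_noise  # assumed path

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for iteration in range(1, 201):
    loss = model(torch.randn(8, 10)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    # sigma = eta / ((iteration // duration) + 1) ** scale_factor, so the noise decays
    add_gradient_noise(model, iteration, duration=100, eta=1.0, scale_factor=0.55)
    optimizer.step()
```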
get_model_conf(model_path, conf_path=None)
¶
Get model config information by reading a model config file (model.json).
Parameters:

Name | Type | Description | Default
---|---|---|---
model_path | str | Model path. | required
conf_path | str | Optional model config path. | None

Returns:

Type | Description
---|---
list[int, int, dict[str, Any]] | Config information loaded from json file.
Source code in adviser/tools/espnet_minimal/asr/asr_utils.py
def get_model_conf(model_path, conf_path=None):
    """Get model config information by reading a model config file (model.json).
    Args:
        model_path (str): Model path.
        conf_path (str): Optional model config path.
    Returns:
        list[int, int, dict[str, Any]]: Config information loaded from json file.
    """
    if conf_path is None:
        model_conf = os.path.dirname(model_path) + '/model.json'
    else:
        model_conf = conf_path
    with open(model_conf, "rb") as f:
        logging.info('reading a config file from ' + model_conf)
        confs = json.load(f)
    if isinstance(confs, dict):
        # for lm
        args = confs
        return argparse.Namespace(**args)
    else:
        # for asr, tts, mt
        idim, odim, args = confs
        return idim, odim, argparse.Namespace(**args)
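A short sketch of reading the config that sits next to a trained model. The path is a placeholder and assumes an ASR/TTS/MT-style `model.json`; for an LM config a single `Namespace` is returned instead.

```python
# Hedged sketch: 'exp/results/model.acc.best' is a placeholder path.
from tools.espnet_minimal.asr.asr_utils import get_model_conf  # assumed path

idim, odim, train_args = get_model_conf('exp/results/model.acc.best')
print(idim, odim, sorted(vars(train_args).keys()))
```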
torch_load(path, model)
¶
Load torch model states.
Parameters:

Name | Type | Description | Default
---|---|---|---
path | str | Model path or snapshot file path to be loaded. | required
model | torch.nn.Module | Torch model. | required
Source code in adviser/tools/espnet_minimal/asr/asr_utils.py
def torch_load(path, model):
    """Load torch model states.
    Args:
        path (str): Model path or snapshot file path to be loaded.
        model (torch.nn.Module): Torch model.
    """
    if 'snapshot' in path:
        model_state_dict = torch.load(path, map_location=lambda storage, loc: storage)['model']
    else:
        model_state_dict = torch.load(path, map_location=lambda storage, loc: storage)
    # debugging:
    # print(model_state_dict)
    if hasattr(model, 'module'):
        model.module.load_state_dict(model_state_dict)
    else:
        model.load_state_dict(model_state_dict)
    del model_state_dict
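A small sketch showing `torch_load` restoring parameters in place. The stand-in model and temporary file are assumptions for illustration.

```python
# Hedged sketch: round-trip a state dict through torch.save / torch_load.
import torch

from tools.espnet_minimal.asr.asr_utils import torch_load  # assumed path

model = torch.nn.Linear(10, 2)                  # stands in for an ESPnet E2E model
torch.save(model.state_dict(), '/tmp/model.acc.best')
torch_load('/tmp/model.acc.best', model)        # no 'snapshot' in path -> plain state dict
```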
torch_save(path, model)
¶
Save torch model states.
Parameters:

Name | Type | Description | Default
---|---|---|---
path | str | Model path to be saved. | required
model | torch.nn.Module | Torch model. | required
Source code in adviser/tools/espnet_minimal/asr/asr_utils.py
pytorch_backend
special
¶
asr_init
¶
Finetuning methods.
filter_modules(model_state_dict, modules)
¶
Filter non-matched modules in module_state_dict.
Parameters:

Name | Type | Description | Default
---|---|---|---
model_state_dict | OrderedDict | trained model state_dict | required
modules | list | specified module list for transfer | required

Returns:

Type | Description
---|---
new_mods (list) | the updated module list
Source code in adviser/tools/espnet_minimal/asr/pytorch_backend/asr_init.py
def filter_modules(model_state_dict, modules):
    """Filter non-matched modules in module_state_dict.
    Args:
        model_state_dict (OrderedDict): trained model state_dict
        modules (list): specified module list for transfer
    Return:
        new_mods (list): the update module list
    """
    new_mods = []
    incorrect_mods = []
    mods_model = list(model_state_dict.keys())
    for mod in modules:
        if any(key.startswith(mod) for key in mods_model):
            new_mods += [mod]
        else:
            incorrect_mods += [mod]
    if incorrect_mods:
        logging.warning("module(s) %s don\'t match or (partially match) "
                        "available modules in model.", incorrect_mods)
        logging.warning('for information, the existing modules in model are:')
        logging.warning('%s', mods_model)
    return new_mods
get_partial_asr_mt_state_dict(model_state_dict, modules)
¶
Create state_dict with specified modules matching input model modules.
Parameters:

Name | Type | Description | Default
---|---|---|---
model_state_dict | OrderedDict | trained model state_dict | required
modules | list | specified module list for transfer | required

Returns:

Type | Description
---|---
new_state_dict (OrderedDict) | the updated state_dict
Source code in adviser/tools/espnet_minimal/asr/pytorch_backend/asr_init.py
def get_partial_asr_mt_state_dict(model_state_dict, modules):
    """Create state_dict with specified modules matching input model modules.
    Args:
        model_state_dict (OrderedDict): trained model state_dict
        modules (list): specified module list for transfer
    Return:
        new_state_dict (OrderedDict): the updated state_dict
    """
    new_state_dict = OrderedDict()
    for key, value in model_state_dict.items():
        if any(key.startswith(m) for m in modules):
            new_state_dict[key] = value
    return new_state_dict
get_partial_lm_state_dict(model_state_dict, modules)
¶
Create compatible ASR state_dict from model_state_dict (LM).
The keys for specified modules are modified to match ASR decoder modules keys.
Parameters:

Name | Type | Description | Default
---|---|---|---
model_state_dict | OrderedDict | trained model state_dict | required
modules | list | specified module list for transfer | required

Returns:

Type | Description
---|---
new_state_dict (OrderedDict) | the updated state_dict
new_mods (list) | the updated module list
Source code in adviser/tools/espnet_minimal/asr/pytorch_backend/asr_init.py
def get_partial_lm_state_dict(model_state_dict, modules):
    """Create compatible ASR state_dict from model_state_dict (LM).
    The keys for specified modules are modified to match ASR decoder modules keys.
    Args:
        model_state_dict (OrderedDict): trained model state_dict
        modules (list): specified module list for transfer
    Return:
        new_state_dict (OrderedDict): the updated state_dict
        new_mods (list): the updated module list
    """
    new_state_dict = OrderedDict()
    new_modules = []
    for key, value in list(model_state_dict.items()):
        if key == "predictor.embed.weight" and "predictor.embed." in modules:
            new_key = "dec.embed.weight"
            new_state_dict[new_key] = value
            new_modules += [new_key]
        elif "predictor.rnn." in key and "predictor.rnn." in modules:
            new_key = "dec.decoder." + key.split("predictor.rnn.", 1)[1]
            new_state_dict[new_key] = value
            new_modules += [new_key]
    return new_state_dict, new_modules
get_root_dir()
¶
get_trained_model_state_dict(model_path)
¶
Extract the trained model state dict for pre-initialization.
Parameters:

Name | Type | Description | Default
---|---|---|---
model_path | str | Path to model.***.best | required

Returns:

Type | Description
---|---
model.state_dict() (OrderedDict) | the loaded model state_dict
(str) | Type of model. Either ASR/MT or LM.
Source code in adviser/tools/espnet_minimal/asr/pytorch_backend/asr_init.py
def get_trained_model_state_dict(model_path):
    """Extract the trained model state dict for pre-initialization.
    Args:
        model_path (str): Path to model.***.best
    Return:
        model.state_dict() (OrderedDict): the loaded model state_dict
        (str): Type of model. Either ASR/MT or LM.
    """
    conf_path = os.path.join(os.path.dirname(model_path), 'model.json')
    if 'rnnlm' in model_path:
        logging.warning('reading model parameters from %s', model_path)
        return torch.load(model_path), 'lm'
    idim, odim, args = get_model_conf(model_path, conf_path)
    logging.warning('reading model parameters from ' + model_path)
    if hasattr(args, "model_module"):
        model_module = args.model_module
    else:
        model_module = "services.hci.speech.espnet_minimal.nets.pytorch_backend.e2e_asr:E2E"
    model_class = dynamic_import(model_module)
    model = model_class(idim, odim, args)
    torch_load(model_path, model)
    assert isinstance(model, MTInterface) or isinstance(model, ASRInterface)
    return model.state_dict(), 'asr-mt'
load_trained_model(model_path)
¶
Load the trained model for recognition.
Parameters:

Name | Type | Description | Default
---|---|---|---
model_path | str | Path to model.***.best | required
Source code in adviser/tools/espnet_minimal/asr/pytorch_backend/asr_init.py
def load_trained_model(model_path):
    """Load the trained model for recognition.
    Args:
        model_path (str): Path to model.***.best
    """
    idim, odim, train_args = get_model_conf(model_path, os.path.join(get_root_dir(), os.path.dirname(model_path), 'model.json'))
    # logging.warning('reading model parameters from ' + model_path)
    if hasattr(train_args, "model_module"):
        model_module = train_args.model_module
    else:
        model_module = "services.hci.speech.espnet_minimal.nets.pytorch_backend.e2e_asr:E2E"
    model_class = dynamic_import(model_module)
    model = model_class(idim, odim, train_args)
    torch_load(model_path, model)
    return model, train_args
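Typical use for decoding, shown as a hedged sketch: it requires an existing trained model file plus its `model.json`; the path below is a placeholder.

```python
# Hedged sketch: load a trained ASR model for recognition (placeholder path).
from tools.espnet_minimal.asr.pytorch_backend.asr_init import load_trained_model  # assumed path

model, train_args = load_trained_model('exp/train/results/model.acc.best')
model.eval()  # switch off dropout etc. before decoding
```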
load_trained_modules(idim, odim, args, interface=<class 'tools.espnet_minimal.nets.asr_interface.ASRInterface'>)
¶
Load model encoder or/and decoder modules with ESPNET pre-trained model(s).
Parameters:

Name | Type | Description | Default
---|---|---|---
idim | int | initial input dimension. | required
odim | int | initial output dimension. | required
args | Namespace | The initial model arguments. | required
interface | Interface | ASRInterface or STInterface | <class 'tools.espnet_minimal.nets.asr_interface.ASRInterface'>

Returns:

Type | Description
---|---
model (torch.nn.Module) | The model with pretrained modules.
Source code in adviser/tools/espnet_minimal/asr/pytorch_backend/asr_init.py
def load_trained_modules(idim, odim, args, interface=ASRInterface):
    """Load model encoder or/and decoder modules with ESPNET pre-trained model(s).
    Args:
        idim (int): initial input dimension.
        odim (int): initial output dimension.
        args (Namespace): The initial model arguments.
        interface (Interface): ASRInterface or STInterface
    Return:
        model (torch.nn.Module): The model with pretrained modules.
    """
    enc_model_path = args.enc_init
    dec_model_path = args.dec_init
    enc_modules = args.enc_init_mods
    dec_modules = args.dec_init_mods
    model_class = dynamic_import(args.model_module)
    main_model = model_class(idim, odim, args)
    assert isinstance(main_model, interface)
    main_state_dict = main_model.state_dict()
    logging.warning('model(s) found for pre-initialization')
    for model_path, modules in [(enc_model_path, enc_modules),
                                (dec_model_path, dec_modules)]:
        if model_path is not None:
            if os.path.isfile(model_path):
                model_state_dict, mode = get_trained_model_state_dict(model_path)
                modules = filter_modules(model_state_dict, modules)
                if mode == 'lm':
                    partial_state_dict, modules = get_partial_lm_state_dict(model_state_dict, modules)
                else:
                    partial_state_dict = get_partial_asr_mt_state_dict(model_state_dict, modules)
                if partial_state_dict:
                    if transfer_verification(main_state_dict, partial_state_dict, modules):
                        logging.warning('loading %s from model: %s', modules, model_path)
                        for k in partial_state_dict.keys():
                            logging.warning('override %s' % k)
                        main_state_dict.update(partial_state_dict)
                    else:
                        logging.warning('modules %s in model %s don\'t match your training config',
                                        modules, model_path)
            else:
                logging.warning('model was not found : %s', model_path)
    main_model.load_state_dict(main_state_dict)
    return main_model
transfer_verification(model_state_dict, partial_state_dict, modules)
¶
Verify tuples (key, shape) for input model modules match specified modules.
Parameters:

Name | Type | Description | Default
---|---|---|---
model_state_dict | OrderedDict | the initial model state_dict | required
partial_state_dict | OrderedDict | the trained model state_dict | required
modules | list | specified module list for transfer | required

Returns:

Type | Description
---|---
(boolean) | allow transfer
Source code in adviser/tools/espnet_minimal/asr/pytorch_backend/asr_init.py
def transfer_verification(model_state_dict, partial_state_dict, modules):
    """Verify tuples (key, shape) for input model modules match specified modules.
    Args:
        model_state_dict (OrderedDict): the initial model state_dict
        partial_state_dict (OrderedDict): the trained model state_dict
        modules (list): specified module list for transfer
    Return:
        (boolean): allow transfer
    """
    partial_modules = []
    for key_p, value_p in partial_state_dict.items():
        if any(key_p.startswith(m) for m in modules):
            if value_p.shape == model_state_dict[key_p].shape:
                partial_modules += [(key_p, value_p.shape)]
    return len(partial_modules) > 0
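A hedged sketch tying the helpers above together: pick the encoder weights out of a trained state_dict, check they are shape-compatible with a new model, and copy them over. The two toy models stand in for ESPnet networks, and the import path is an assumption.

```python
# Hedged sketch: module transfer with toy models (not ESPnet networks).
from collections import OrderedDict

import torch

from tools.espnet_minimal.asr.pytorch_backend.asr_init import (  # assumed path
    filter_modules, get_partial_asr_mt_state_dict, transfer_verification)

trained = torch.nn.Sequential(OrderedDict([('enc', torch.nn.Linear(4, 4)),
                                           ('dec', torch.nn.Linear(4, 2))]))
new_model = torch.nn.Sequential(OrderedDict([('enc', torch.nn.Linear(4, 4)),
                                             ('dec', torch.nn.Linear(4, 3))]))

mods = filter_modules(trained.state_dict(), ['enc.'])
partial = get_partial_asr_mt_state_dict(trained.state_dict(), mods)
if transfer_verification(new_model.state_dict(), partial, mods):
    state = new_model.state_dict()
    state.update(partial)              # only the matching encoder weights change
    new_model.load_state_dict(state)
```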
bin
special
¶
asr_recog
¶
End-to-end speech recognition model decoding script.
get_parser()
¶
Get default arguments.
Source code in adviser/tools/espnet_minimal/bin/asr_recog.py
def get_parser():
"""Get default arguments."""
parser = configargparse.ArgumentParser(
description='Transcribe text from speech using a speech recognition model on one CPU or GPU',
config_file_parser_class=configargparse.YAMLConfigFileParser,
formatter_class=configargparse.ArgumentDefaultsHelpFormatter)
# general configuration
parser.add('--config', is_config_file=True,
help='Config file path')
parser.add('--config2', is_config_file=True,
help='Second config file path that overwrites the settings in `--config`')
parser.add('--config3', is_config_file=True,
help='Third config file path that overwrites the settings in `--config` and `--config2`')
parser.add_argument('--ngpu', type=int, default=0,
help='Number of GPUs')
parser.add_argument('--dtype', choices=("float16", "float32", "float64"), default="float32",
help='Float precision (only available in --api v2)')
parser.add_argument('--backend', type=str, default='chainer',
choices=['chainer', 'pytorch'],
help='Backend library')
parser.add_argument('--debugmode', type=int, default=1,
help='Debugmode')
parser.add_argument('--seed', type=int, default=1,
help='Random seed')
parser.add_argument('--verbose', '-V', type=int, default=1,
help='Verbose option')
parser.add_argument('--batchsize', type=int, default=1,
help='Batch size for beam search (0: means no batch processing)')
parser.add_argument('--preprocess-conf', type=str, default=None,
help='The configuration file for the pre-processing')
parser.add_argument('--api', default="v1", choices=["v1", "v2"],
help='''Beam search APIs
v1: Default API. It only supports the ASRInterface.recognize method and DefaultRNNLM.
v2: Experimental API. It supports any model that implements ScorerInterface.''')
# task related
parser.add_argument('--recog-json', type=str,
help='Filename of recognition data (json)')
parser.add_argument('--result-label', type=str, required=True,
help='Filename of result label data (json)')
# model (parameter) related
parser.add_argument('--model', type=str, required=True,
help='Model file parameters to read')
parser.add_argument('--model-conf', type=str, default=None,
help='Model config file')
parser.add_argument('--num-spkrs', type=int, default=1,
choices=[1, 2],
help='Number of speakers in the speech')
parser.add_argument('--num-encs', default=1, type=int,
help='Number of encoders in the model.')
# search related
parser.add_argument('--nbest', type=int, default=1,
help='Output N-best hypotheses')
parser.add_argument('--beam-size', type=int, default=1,
help='Beam size')
parser.add_argument('--penalty', type=float, default=0.0,
help='Insertion penalty')
parser.add_argument('--maxlenratio', type=float, default=0.0,
help="""Input length ratio to obtain max output length.
If maxlenratio=0.0 (default), it uses an end-detect function
to automatically find maximum hypothesis lengths""")
parser.add_argument('--minlenratio', type=float, default=0.0,
help='Input length ratio to obtain min output length')
parser.add_argument('--ctc-weight', type=float, default=0.0,
help='CTC weight in joint decoding')
parser.add_argument('--weights-ctc-dec', type=float, action='append',
help='ctc weight assigned to each encoder during decoding.[in multi-encoder mode only]')
parser.add_argument('--ctc-window-margin', type=int, default=0,
help="""Use CTC window with margin parameter to accelerate
CTC/attention decoding especially on GPU. Smaller margin
makes decoding faster, but may increase search errors.
If margin=0 (default), this function is disabled""")
# transducer related
parser.add_argument('--score-norm-transducer', type=strtobool, nargs='?',
default=True,
help='Normalize transducer scores by length')
# rnnlm related
parser.add_argument('--rnnlm', type=str, default=None,
help='RNNLM model file to read')
parser.add_argument('--rnnlm-conf', type=str, default=None,
help='RNNLM model config file to read')
parser.add_argument('--word-rnnlm', type=str, default=None,
help='Word RNNLM model file to read')
parser.add_argument('--word-rnnlm-conf', type=str, default=None,
help='Word RNNLM model config file to read')
parser.add_argument('--word-dict', type=str, default=None,
help='Word list to read')
parser.add_argument('--lm-weight', type=float, default=0.1,
help='RNNLM weight')
# streaming related
parser.add_argument('--streaming-mode', type=str, default=None,
choices=['window', 'segment'],
help="""Use streaming recognizer for inference.
`--batchsize` must be set to 0 to enable this mode""")
parser.add_argument('--streaming-window', type=int, default=10,
help='Window size')
parser.add_argument('--streaming-min-blank-dur', type=int, default=10,
help='Minimum blank duration threshold')
parser.add_argument('--streaming-onset-margin', type=int, default=1,
help='Onset margin')
parser.add_argument('--streaming-offset-margin', type=int, default=1,
help='Offset margin')
return parser
main(args)
¶
Run the main decoding function.
Source code in adviser/tools/espnet_minimal/bin/asr_recog.py
def main(args):
"""Run the main decoding function."""
parser = get_parser()
args = parser.parse_args(args)
if args.ngpu == 0 and args.dtype == "float16":
raise ValueError(f"--dtype {args.dtype} does not support the CPU backend.")
# logging info
if args.verbose == 1:
logging.basicConfig(
level=logging.INFO, format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
elif args.verbose == 2:
logging.basicConfig(level=logging.DEBUG,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
else:
logging.basicConfig(
level=logging.WARN, format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
logging.warning("Skip DEBUG/INFO messages")
# check CUDA_VISIBLE_DEVICES
if args.ngpu > 0:
cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
if cvd is None:
logging.warning("CUDA_VISIBLE_DEVICES is not set.")
elif args.ngpu != len(cvd.split(",")):
logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.")
sys.exit(1)
# TODO(mn5k): support of multiple GPUs
if args.ngpu > 1:
logging.error("The program only supports ngpu=1.")
sys.exit(1)
# display PYTHONPATH
logging.info('python path = ' + os.environ.get('PYTHONPATH', '(None)'))
# seed setting
random.seed(args.seed)
np.random.seed(args.seed)
logging.info('set random seed = %d' % args.seed)
# validate rnn options
if args.rnnlm is not None and args.word_rnnlm is not None:
logging.error("It seems that both --rnnlm and --word-rnnlm are specified. Please use either option.")
sys.exit(1)
# recog
logging.info('backend = ' + args.backend)
if args.num_spkrs == 1:
if args.backend == "chainer":
from tools.espnet_minimal.asr.chainer_backend.asr import recog
recog(args)
elif args.backend == "pytorch":
if args.num_encs == 1:
# Experimental API that supports custom LMs
if args.api == "v2":
from tools.espnet_minimal.asr.pytorch_backend.recog import recog_v2
recog_v2(args)
else:
from tools.espnet_minimal.asr.pytorch_backend.asr import recog
if args.dtype != "float32":
raise NotImplementedError(f"`--dtype {args.dtype}` is only available with `--api v2`")
recog(args)
else:
if args.api == "v2":
raise NotImplementedError(f"--num-encs {args.num_encs} > 1 is not supported in --api v2")
else:
from tools.espnet_minimal.asr.pytorch_backend.asr import recog
recog(args)
else:
raise ValueError("Only chainer and pytorch are supported.")
elif args.num_spkrs == 2:
if args.backend == "pytorch":
from tools.espnet_minimal.asr.pytorch_backend.asr_mix import recog
recog(args)
else:
raise ValueError("Only pytorch is supported.")
asr_train
¶
Automatic speech recognition model training script.
is_torch_1_2_plus
¶
get_parser(parser=None, required=True)
¶
Get default arguments.
Source code in adviser/tools/espnet_minimal/bin/asr_train.py
def get_parser(parser=None, required=True):
"""Get default arguments."""
if parser is None:
parser = configargparse.ArgumentParser(
description="Train an automatic speech recognition (ASR) model on one CPU, one or multiple GPUs",
config_file_parser_class=configargparse.YAMLConfigFileParser,
formatter_class=configargparse.ArgumentDefaultsHelpFormatter)
# general configuration
parser.add('--config', is_config_file=True, help='config file path')
parser.add('--config2', is_config_file=True,
help='second config file path that overwrites the settings in `--config`.')
parser.add('--config3', is_config_file=True,
help='third config file path that overwrites the settings in `--config` and `--config2`.')
parser.add_argument('--ngpu', default=None, type=int,
help='Number of GPUs. If not given, use all visible devices')
parser.add_argument('--train-dtype', default="float32",
choices=["float16", "float32", "float64", "O0", "O1", "O2", "O3"],
help='Data type for training (only pytorch backend). '
'O0,O1,.. flags require apex. See https://nvidia.github.io/apex/amp.html#opt-levels')
parser.add_argument('--backend', default='chainer', type=str,
choices=['chainer', 'pytorch'],
help='Backend library')
parser.add_argument('--outdir', type=str, required=required,
help='Output directory')
parser.add_argument('--debugmode', default=1, type=int,
help='Debugmode')
parser.add_argument('--dict', required=required,
help='Dictionary')
parser.add_argument('--seed', default=1, type=int,
help='Random seed')
parser.add_argument('--debugdir', type=str,
help='Output directory for debugging')
parser.add_argument('--resume', '-r', default='', nargs='?',
help='Resume the training from snapshot')
parser.add_argument('--minibatches', '-N', type=int, default='-1',
help='Process only N minibatches (for debug)')
parser.add_argument('--verbose', '-V', default=0, type=int,
help='Verbose option')
parser.add_argument('--tensorboard-dir', default=None, type=str, nargs='?', help="Tensorboard log dir path")
parser.add_argument('--report-interval-iters', default=100, type=int,
help="Report interval iterations")
parser.add_argument('--save-interval-iters', default=0, type=int,
help="Save snapshot interval iterations")
# task related
parser.add_argument('--train-json', type=str, default=None,
help='Filename of train label data (json)')
parser.add_argument('--valid-json', type=str, default=None,
help='Filename of validation label data (json)')
# network architecture
parser.add_argument('--model-module', type=str, default=None,
help='model defined module (default: services.hci.speech.espnet_minimal.nets.xxx_backend.e2e_asr:E2E)')
# encoder
parser.add_argument('--num-encs', default=1, type=int,
help='Number of encoders in the model.')
# loss related
parser.add_argument('--ctc_type', default='warpctc', type=str,
choices=['builtin', 'warpctc'],
help='Type of CTC implementation to calculate loss.')
parser.add_argument('--mtlalpha', default=0.5, type=float,
help='Multitask learning coefficient, alpha: alpha*ctc_loss + (1-alpha)*att_loss ')
parser.add_argument('--lsm-weight', default=0.0, type=float,
help='Label smoothing weight')
# recognition options to compute CER/WER
parser.add_argument('--report-cer', default=False, action='store_true',
help='Compute CER on development set')
parser.add_argument('--report-wer', default=False, action='store_true',
help='Compute WER on development set')
parser.add_argument('--nbest', type=int, default=1,
help='Output N-best hypotheses')
parser.add_argument('--beam-size', type=int, default=4,
help='Beam size')
parser.add_argument('--penalty', default=0.0, type=float,
help='Insertion penalty')
parser.add_argument('--maxlenratio', default=0.0, type=float,
help="""Input length ratio to obtain max output length.
If maxlenratio=0.0 (default), it uses an end-detect function
to automatically find maximum hypothesis lengths""")
parser.add_argument('--minlenratio', default=0.0, type=float,
help='Input length ratio to obtain min output length')
parser.add_argument('--ctc-weight', default=0.3, type=float,
help='CTC weight in joint decoding')
parser.add_argument('--rnnlm', type=str, default=None,
help='RNNLM model file to read')
parser.add_argument('--rnnlm-conf', type=str, default=None,
help='RNNLM model config file to read')
parser.add_argument('--lm-weight', default=0.1, type=float,
help='RNNLM weight.')
parser.add_argument('--sym-space', default='<space>', type=str,
help='Space symbol')
parser.add_argument('--sym-blank', default='<blank>', type=str,
help='Blank symbol')
# minibatch related
parser.add_argument('--sortagrad', default=0, type=int, nargs='?',
help="How many epochs to use sortagrad for. 0 = deactivated, -1 = all epochs")
parser.add_argument('--batch-count', default='auto', choices=BATCH_COUNT_CHOICES,
help='How to count batch_size. The default (auto) will find how to count by args.')
parser.add_argument('--batch-size', '--batch-seqs', '-b', default=0, type=int,
help='Maximum seqs in a minibatch (0 to disable)')
parser.add_argument('--batch-bins', default=0, type=int,
help='Maximum bins in a minibatch (0 to disable)')
parser.add_argument('--batch-frames-in', default=0, type=int,
help='Maximum input frames in a minibatch (0 to disable)')
parser.add_argument('--batch-frames-out', default=0, type=int,
help='Maximum output frames in a minibatch (0 to disable)')
parser.add_argument('--batch-frames-inout', default=0, type=int,
help='Maximum input+output frames in a minibatch (0 to disable)')
parser.add_argument('--maxlen-in', '--batch-seq-maxlen-in', default=800, type=int, metavar='ML',
help='When --batch-count=seq, batch size is reduced if the input sequence length > ML.')
parser.add_argument('--maxlen-out', '--batch-seq-maxlen-out', default=150, type=int, metavar='ML',
help='When --batch-count=seq, batch size is reduced if the output sequence length > ML')
parser.add_argument('--n-iter-processes', default=0, type=int,
help='Number of processes of iterator')
parser.add_argument('--preprocess-conf', type=str, default=None, nargs='?',
help='The configuration file for the pre-processing')
# optimization related
parser.add_argument('--opt', default='adadelta', type=str,
choices=['adadelta', 'adam', 'noam'],
help='Optimizer')
parser.add_argument('--accum-grad', default=1, type=int,
help='Number of gradient accumulation')
parser.add_argument('--eps', default=1e-8, type=float,
help='Epsilon constant for optimizer')
parser.add_argument('--eps-decay', default=0.01, type=float,
help='Decaying ratio of epsilon')
parser.add_argument('--weight-decay', default=0.0, type=float,
help='Weight decay ratio')
parser.add_argument('--criterion', default='acc', type=str,
choices=['loss', 'acc'],
help='Criterion to perform epsilon decay')
parser.add_argument('--threshold', default=1e-4, type=float,
help='Threshold to stop iteration')
parser.add_argument('--epochs', '-e', default=30, type=int,
help='Maximum number of epochs')
parser.add_argument('--early-stop-criterion', default='validation/main/acc', type=str, nargs='?',
help="Value to monitor to trigger an early stopping of the training")
parser.add_argument('--patience', default=3, type=int, nargs='?',
help="Number of epochs to wait without improvement before stopping the training")
parser.add_argument('--grad-clip', default=5, type=float,
help='Gradient norm threshold to clip')
parser.add_argument('--num-save-attention', default=3, type=int,
help='Number of samples of attention to be saved')
parser.add_argument('--grad-noise', type=strtobool, default=False,
help='The flag to switch to use noise injection to gradients during training')
# asr_mix related
parser.add_argument('--num-spkrs', default=1, type=int,
choices=[1, 2],
help='Number of speakers in the speech.')
# decoder related
parser.add_argument('--context-residual', default=False, type=strtobool, nargs='?',
help='The flag to switch to use context vector residual in the decoder network')
# finetuning related
parser.add_argument('--enc-init', default=None, type=str,
help='Pre-trained ASR model to initialize encoder.')
parser.add_argument('--enc-init-mods', default='enc.enc.',
type=lambda s: [str(mod) for mod in s.split(',') if s != ''],
help='List of encoder modules to initialize, separated by a comma.')
parser.add_argument('--dec-init', default=None, type=str,
help='Pre-trained ASR, MT or LM model to initialize decoder.')
parser.add_argument('--dec-init-mods', default='att., dec.',
type=lambda s: [str(mod) for mod in s.split(',') if s != ''],
help='List of decoder modules to initialize, separated by a comma.')
# front end related
parser.add_argument('--use-frontend', type=strtobool, default=False,
help='The flag to switch to use frontend system.')
# WPE related
parser.add_argument('--use-wpe', type=strtobool, default=False,
help='Apply Weighted Prediction Error')
parser.add_argument('--wtype', default='blstmp', type=str,
choices=['lstm', 'blstm', 'lstmp', 'blstmp', 'vgglstmp', 'vggblstmp', 'vgglstm', 'vggblstm',
'gru', 'bgru', 'grup', 'bgrup', 'vgggrup', 'vggbgrup', 'vgggru', 'vggbgru'],
help='Type of encoder network architecture '
'of the mask estimator for WPE. '
'')
parser.add_argument('--wlayers', type=int, default=2,
help='')
parser.add_argument('--wunits', type=int, default=300,
help='')
parser.add_argument('--wprojs', type=int, default=300,
help='')
parser.add_argument('--wdropout-rate', type=float, default=0.0,
help='')
parser.add_argument('--wpe-taps', type=int, default=5,
help='')
parser.add_argument('--wpe-delay', type=int, default=3,
help='')
parser.add_argument('--use-dnn-mask-for-wpe', type=strtobool,
default=False,
help='Use DNN to estimate the power spectrogram. '
'This option is experimental.')
# Beamformer related
parser.add_argument('--use-beamformer', type=strtobool,
default=True, help='')
parser.add_argument('--btype', default='blstmp', type=str,
choices=['lstm', 'blstm', 'lstmp', 'blstmp', 'vgglstmp', 'vggblstmp', 'vgglstm', 'vggblstm',
'gru', 'bgru', 'grup', 'bgrup', 'vgggrup', 'vggbgrup', 'vgggru', 'vggbgru'],
help='Type of encoder network architecture '
'of the mask estimator for Beamformer.')
parser.add_argument('--blayers', type=int, default=2,
help='')
parser.add_argument('--bunits', type=int, default=300,
help='')
parser.add_argument('--bprojs', type=int, default=300,
help='')
parser.add_argument('--badim', type=int, default=320,
help='')
parser.add_argument('--bnmask', type=int, default=2,
help='Number of beamforming masks, '
'default is 2 for [speech, noise].')
parser.add_argument('--ref-channel', type=int, default=-1,
help='The reference channel used for beamformer. '
'By default, the channel is estimated by DNN.')
parser.add_argument('--bdropout-rate', type=float, default=0.0,
help='')
# Feature transform: Normalization
parser.add_argument('--stats-file', type=str, default=None,
help='The stats file for the feature normalization')
parser.add_argument('--apply-uttmvn', type=strtobool, default=True,
help='Apply utterance level mean '
'variance normalization.')
parser.add_argument('--uttmvn-norm-means', type=strtobool,
default=True, help='')
parser.add_argument('--uttmvn-norm-vars', type=strtobool, default=False,
help='')
# Feature transform: Fbank
parser.add_argument('--fbank-fs', type=int, default=16000,
help='The sample frequency used for '
'the mel-fbank creation.')
parser.add_argument('--n-mels', type=int, default=80,
help='The number of mel-frequency bins.')
parser.add_argument('--fbank-fmin', type=float, default=0.,
help='')
parser.add_argument('--fbank-fmax', type=float, default=None,
help='')
return parser
main(cmd_args)
¶
Run the main training function.
Source code in adviser/tools/espnet_minimal/bin/asr_train.py
def main(cmd_args):
"""Run the main training function."""
parser = get_parser()
args, _ = parser.parse_known_args(cmd_args)
if args.backend == "chainer" and args.train_dtype != "float32":
raise NotImplementedError(
f"chainer backend does not support --train-dtype {args.train_dtype}."
"Use --dtype float32.")
if args.ngpu == 0 and args.train_dtype in ("O0", "O1", "O2", "O3", "float16"):
raise ValueError(f"--train-dtype {args.train_dtype} does not support the CPU backend.")
from tools.espnet_minimal import dynamic_import
if args.model_module is None:
model_module = "services.hci.speech.espnet_minimal.nets." + args.backend + "_backend.e2e_asr:E2E"
else:
model_module = args.model_module
model_class = dynamic_import(model_module)
model_class.add_arguments(parser)
args = parser.parse_args(cmd_args)
args.model_module = model_module
if 'chainer_backend' in args.model_module:
args.backend = 'chainer'
if 'pytorch_backend' in args.model_module:
args.backend = 'pytorch'
# logging info
if args.verbose > 0:
logging.basicConfig(
level=logging.INFO, format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s')
else:
logging.basicConfig(
level=logging.WARN, format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s')
logging.warning('Skip DEBUG/INFO messages')
# If --ngpu is not given,
# 1. if CUDA_VISIBLE_DEVICES is set, all visible devices
# 2. if nvidia-smi exists, use all devices
# 3. else ngpu=0
if args.ngpu is None:
cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
if cvd is not None:
ngpu = len(cvd.split(','))
else:
logging.warning("CUDA_VISIBLE_DEVICES is not set.")
try:
p = subprocess.run(['nvidia-smi', '-L'],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
except (subprocess.CalledProcessError, FileNotFoundError):
ngpu = 0
else:
ngpu = len(p.stderr.decode().split('\n')) - 1
else:
if is_torch_1_2_plus:
assert args.ngpu == 1, "There are some bugs with multi-GPU processing in PyTorch 1.2+" \
" (see https://github.com/pytorch/pytorch/issues/21108)"
ngpu = args.ngpu
logging.info(f"ngpu: {ngpu}")
# display PYTHONPATH
logging.info('python path = ' + os.environ.get('PYTHONPATH', '(None)'))
# set random seed
logging.info('random seed = %d' % args.seed)
random.seed(args.seed)
np.random.seed(args.seed)
# load dictionary for debug log
if args.dict is not None:
with open(args.dict, 'rb') as f:
dictionary = f.readlines()
char_list = [entry.decode('utf-8').split(' ')[0]
for entry in dictionary]
char_list.insert(0, '<blank>')
char_list.append('<eos>')
args.char_list = char_list
else:
args.char_list = None
# train
logging.info('backend = ' + args.backend)
if args.num_spkrs == 1:
if args.backend == "chainer":
from tools.espnet_minimal.asr.chainer_backend.asr import train
train(args)
elif args.backend == "pytorch":
from tools.espnet_minimal.asr.pytorch_backend.asr import train
train(args)
else:
raise ValueError("Only chainer and pytorch are supported.")
else:
# FIXME(kamo): Support --model-module
if args.backend == "pytorch":
from tools.espnet_minimal.asr.pytorch_backend.asr_mix import train
train(args)
else:
raise ValueError("Only pytorch is supported.")
lm_train
¶
Language model training script.
get_parser()
¶
Get parser.
Source code in adviser/tools/espnet_minimal/bin/lm_train.py
def get_parser():
"""Get parser."""
parser = configargparse.ArgumentParser(
description='Train a new language model on one CPU or one GPU',
config_file_parser_class=configargparse.YAMLConfigFileParser,
formatter_class=configargparse.ArgumentDefaultsHelpFormatter)
# general configuration
parser.add('--config', is_config_file=True, help='config file path')
parser.add('--config2', is_config_file=True,
help='second config file path that overwrites the settings in `--config`.')
parser.add('--config3', is_config_file=True,
help='third config file path that overwrites the settings in `--config` and `--config2`.')
parser.add_argument('--ngpu', default=None, type=int,
help='Number of GPUs. If not given, use all visible devices')
parser.add_argument('--train-dtype', default="float32",
choices=["float16", "float32", "float64", "O0", "O1", "O2", "O3"],
help='Data type for training (only pytorch backend). '
'O0,O1,.. flags require apex. See https://nvidia.github.io/apex/amp.html#opt-levels')
parser.add_argument('--backend', default='chainer', type=str,
choices=['chainer', 'pytorch'],
help='Backend library')
parser.add_argument('--outdir', type=str, required=True,
help='Output directory')
parser.add_argument('--debugmode', default=1, type=int,
help='Debugmode')
parser.add_argument('--dict', type=str, required=True,
help='Dictionary')
parser.add_argument('--seed', default=1, type=int,
help='Random seed')
parser.add_argument('--resume', '-r', default='', nargs='?',
help='Resume the training from snapshot')
parser.add_argument('--verbose', '-V', default=0, type=int,
help='Verbose option')
parser.add_argument('--tensorboard-dir', default=None, type=str, nargs='?', help="Tensorboard log dir path")
parser.add_argument('--report-interval-iters', default=100, type=int,
help="Report interval iterations")
# task related
parser.add_argument('--train-label', type=str, required=True,
help='Filename of train label data')
parser.add_argument('--valid-label', type=str, required=True,
help='Filename of validation label data')
parser.add_argument('--test-label', type=str,
help='Filename of test label data')
parser.add_argument('--dump-hdf5-path', type=str, default=None,
help='Path to dump a preprocessed dataset as hdf5')
# training configuration
parser.add_argument('--opt', default='sgd', type=str,
choices=['sgd', 'adam'],
help='Optimizer')
parser.add_argument('--sortagrad', default=0, type=int, nargs='?',
help="How many epochs to use sortagrad for. 0 = deactivated, -1 = all epochs")
parser.add_argument('--batchsize', '-b', type=int, default=300,
help='Number of examples in each mini-batch')
parser.add_argument('--epoch', '-e', type=int, default=20,
help='Number of sweeps over the dataset to train')
parser.add_argument('--early-stop-criterion', default='validation/main/loss', type=str, nargs='?',
help="Value to monitor to trigger an early stopping of the training")
parser.add_argument('--patience', default=3, type=int, nargs='?',
help="Number of epochs to wait without improvement before stopping the training")
parser.add_argument('--gradclip', '-c', type=float, default=5,
help='Gradient norm threshold to clip')
parser.add_argument('--maxlen', type=int, default=40,
help='Batch size is reduced if the input sequence > ML')
parser.add_argument('--model-module', type=str, default='default',
help='model defined module (default: services.hci.speech.espnet_minimal.nets.xxx_backend.lm.default:DefaultRNNLM)')
return parser
main(cmd_args)
¶
Train LM.
Source code in adviser/tools/espnet_minimal/bin/lm_train.py
def main(cmd_args):
"""Train LM."""
parser = get_parser()
args, _ = parser.parse_known_args(cmd_args)
if args.backend == "chainer" and args.train_dtype != "float32":
raise NotImplementedError(
f"chainer backend does not support --train-dtype {args.train_dtype}."
"Use --dtype float32.")
if args.ngpu == 0 and args.train_dtype in ("O0", "O1", "O2", "O3", "float16"):
raise ValueError(f"--train-dtype {args.train_dtype} does not support the CPU backend.")
# parse model-specific arguments dynamically
model_class = dynamic_import_lm(args.model_module, args.backend)
model_class.add_arguments(parser)
args = parser.parse_args(cmd_args)
# logging info
if args.verbose > 0:
logging.basicConfig(
level=logging.INFO, format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s')
else:
logging.basicConfig(
level=logging.WARN, format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s')
logging.warning('Skip DEBUG/INFO messages')
# If --ngpu is not given,
# 1. if CUDA_VISIBLE_DEVICES is set, all visible devices
# 2. if nvidia-smi exists, use all devices
# 3. else ngpu=0
if args.ngpu is None:
cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
if cvd is not None:
ngpu = len(cvd.split(','))
else:
logging.warning("CUDA_VISIBLE_DEVICES is not set.")
try:
p = subprocess.run(['nvidia-smi', '-L'],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
except (subprocess.CalledProcessError, FileNotFoundError):
ngpu = 0
else:
ngpu = len(p.stderr.decode().split('\n')) - 1
else:
ngpu = args.ngpu
logging.info(f"ngpu: {ngpu}")
# display PYTHONPATH
logging.info('python path = ' + os.environ.get('PYTHONPATH', '(None)'))
# seed setting
nseed = args.seed
random.seed(nseed)
np.random.seed(nseed)
# load dictionary
with open(args.dict, 'rb') as f:
dictionary = f.readlines()
char_list = [entry.decode('utf-8').split(' ')[0] for entry in dictionary]
char_list.insert(0, '<blank>')
char_list.append('<eos>')
args.char_list_dict = {x: i for i, x in enumerate(char_list)}
args.n_vocab = len(char_list)
# train
logging.info('backend = ' + args.backend)
if args.backend == "chainer":
from tools.espnet_minimal import train
train(args)
elif args.backend == "pytorch":
from tools.espnet_minimal.lm.pytorch_backend.lm import train
train(args)
else:
raise ValueError("Only chainer and pytorch are supported.")
mt_trans
¶
Neural machine translation model decoding script.
get_parser()
¶
Get default arguments.
Source code in adviser/tools/espnet_minimal/bin/mt_trans.py
def get_parser():
"""Get default arguments."""
parser = configargparse.ArgumentParser(
description='Translate text from speech using a speech translation model on one CPU or GPU',
config_file_parser_class=configargparse.YAMLConfigFileParser,
formatter_class=configargparse.ArgumentDefaultsHelpFormatter)
# general configuration
parser.add('--config', is_config_file=True,
help='Config file path')
parser.add('--config2', is_config_file=True,
help='Second config file path that overwrites the settings in `--config`')
parser.add('--config3', is_config_file=True,
help='Third config file path that overwrites the settings in `--config` and `--config2`')
parser.add_argument('--ngpu', type=int, default=0,
help='Number of GPUs')
parser.add_argument('--dtype', choices=("float16", "float32", "float64"), default="float32",
help='Float precision (only available in --api v2)')
parser.add_argument('--backend', type=str, default='chainer',
choices=['chainer', 'pytorch'],
help='Backend library')
parser.add_argument('--debugmode', type=int, default=1,
help='Debugmode')
parser.add_argument('--seed', type=int, default=1,
help='Random seed')
parser.add_argument('--verbose', '-V', type=int, default=1,
help='Verbose option')
parser.add_argument('--batchsize', type=int, default=1,
help='Batch size for beam search (0: means no batch processing)')
parser.add_argument('--preprocess-conf', type=str, default=None,
help='The configuration file for the pre-processing')
parser.add_argument('--api', default="v1", choices=["v1", "v2"],
help='''Beam search APIs
v1: Default API. It only supports the ASRInterface.recognize method and DefaultRNNLM.
v2: Experimental API. It supports any model that implements ScorerInterface.''')
# task related
parser.add_argument('--trans-json', type=str,
help='Filename of translation data (json)')
parser.add_argument('--result-label', type=str, required=True,
help='Filename of result label data (json)')
# model (parameter) related
parser.add_argument('--model', type=str, required=True,
help='Model file parameters to read')
parser.add_argument('--model-conf', type=str, default=None,
help='Model config file')
# search related
parser.add_argument('--nbest', type=int, default=1,
help='Output N-best hypotheses')
parser.add_argument('--beam-size', type=int, default=1,
help='Beam size')
parser.add_argument('--penalty', type=float, default=0.1,
help='Insertion penalty')
parser.add_argument('--maxlenratio', type=float, default=3.0,
help="""Input length ratio to obtain max output length.
If maxlenratio=0.0 (default), it uses an end-detect function
to automatically find maximum hypothesis lengths""")
parser.add_argument('--minlenratio', type=float, default=0.0,
help='Input length ratio to obtain min output length')
# rnnlm related
parser.add_argument('--rnnlm', type=str, default=None,
help='RNNLM model file to read')
parser.add_argument('--rnnlm-conf', type=str, default=None,
help='RNNLM model config file to read')
parser.add_argument('--lm-weight', type=float, default=0.0,
help='RNNLM weight')
# multilingual related
parser.add_argument('--tgt-lang', default=False, type=str,
help='target language ID (e.g., <en>, <de>, and <fr> etc.)')
return parser
main(args)
¶
Run the main decoding function.
Source code in adviser/tools/espnet_minimal/bin/mt_trans.py
def main(args):
"""Run the main decoding function."""
parser = get_parser()
args = parser.parse_args(args)
# logging info
if args.verbose == 1:
logging.basicConfig(
level=logging.INFO, format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
elif args.verbose == 2:
logging.basicConfig(level=logging.DEBUG,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
else:
logging.basicConfig(
level=logging.WARN, format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
logging.warning("Skip DEBUG/INFO messages")
# check CUDA_VISIBLE_DEVICES
if args.ngpu > 0:
cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
if cvd is None:
logging.warning("CUDA_VISIBLE_DEVICES is not set.")
elif args.ngpu != len(cvd.split(",")):
logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.")
sys.exit(1)
# TODO(mn5k): support of multiple GPUs
if args.ngpu > 1:
logging.error("The program only supports ngpu=1.")
sys.exit(1)
# display PYTHONPATH
logging.info('python path = ' + os.environ.get('PYTHONPATH', '(None)'))
# seed setting
random.seed(args.seed)
np.random.seed(args.seed)
logging.info('set random seed = %d' % args.seed)
# trans
logging.info('backend = ' + args.backend)
if args.backend == "pytorch":
# Experimental API that supports custom LMs
from tools.espnet_minimal.mt.pytorch_backend.mt import trans
if args.dtype != "float32":
raise NotImplementedError(f"`--dtype {args.dtype}` is only available with `--api v2`")
trans(args)
else:
raise ValueError("Only pytorch are supported.")
st_trans
¶
End-to-end speech translation model decoding script.
get_parser()
¶
Get default arguments.
Source code in adviser/tools/espnet_minimal/bin/st_trans.py
def get_parser():
"""Get default arguments."""
parser = configargparse.ArgumentParser(
description='Translate text from speech using a speech translation model on one CPU or GPU',
config_file_parser_class=configargparse.YAMLConfigFileParser,
formatter_class=configargparse.ArgumentDefaultsHelpFormatter)
# general configuration
parser.add('--config', is_config_file=True,
help='Config file path')
parser.add('--config2', is_config_file=True,
help='Second config file path that overwrites the settings in `--config`')
parser.add('--config3', is_config_file=True,
help='Third config file path that overwrites the settings in `--config` and `--config2`')
parser.add_argument('--ngpu', type=int, default=0,
help='Number of GPUs')
parser.add_argument('--dtype', choices=("float16", "float32", "float64"), default="float32",
help='Float precision (only available in --api v2)')
parser.add_argument('--backend', type=str, default='chainer',
choices=['chainer', 'pytorch'],
help='Backend library')
parser.add_argument('--debugmode', type=int, default=1,
help='Debugmode')
parser.add_argument('--seed', type=int, default=1,
help='Random seed')
parser.add_argument('--verbose', '-V', type=int, default=1,
help='Verbose option')
parser.add_argument('--batchsize', type=int, default=1,
help='Batch size for beam search (0: means no batch processing)')
parser.add_argument('--preprocess-conf', type=str, default=None,
help='The configuration file for the pre-processing')
parser.add_argument('--api', default="v1", choices=["v1", "v2"],
help='''Beam search APIs
v1: Default API. It only supports the ASRInterface.recognize method and DefaultRNNLM.
v2: Experimental API. It supports any model that implements ScorerInterface.''')
# task related
parser.add_argument('--trans-json', type=str,
help='Filename of translation data (json)')
parser.add_argument('--result-label', type=str, required=True,
help='Filename of result label data (json)')
# model (parameter) related
parser.add_argument('--model', type=str, required=True,
help='Model file parameters to read')
# search related
parser.add_argument('--nbest', type=int, default=1,
help='Output N-best hypotheses')
parser.add_argument('--beam-size', type=int, default=1,
help='Beam size')
parser.add_argument('--penalty', type=float, default=0.0,
help='Insertion penalty')
parser.add_argument('--maxlenratio', type=float, default=0.0,
help="""Input length ratio to obtain max output length.
If maxlenratio=0.0 (default), it uses an end-detect function
to automatically find maximum hypothesis lengths""")
parser.add_argument('--minlenratio', type=float, default=0.0,
help='Input length ratio to obtain min output length')
# rnnlm related
parser.add_argument('--rnnlm', type=str, default=None,
help='RNNLM model file to read')
parser.add_argument('--rnnlm-conf', type=str, default=None,
help='RNNLM model config file to read')
parser.add_argument('--word-rnnlm', type=str, default=None,
help='Word RNNLM model file to read')
parser.add_argument('--word-rnnlm-conf', type=str, default=None,
help='Word RNNLM model config file to read')
parser.add_argument('--word-dict', type=str, default=None,
help='Word list to read')
parser.add_argument('--lm-weight', type=float, default=0.1,
help='RNNLM weight')
# multilingual related
parser.add_argument('--tgt-lang', default=False, type=str,
help='target language ID (e.g., <en>, <de>, and <fr> etc.)')
return parser
main(args)
¶
Run the main decoding function.
Source code in adviser/tools/espnet_minimal/bin/st_trans.py
def main(args):
"""Run the main decoding function."""
parser = get_parser()
args = parser.parse_args(args)
# logging info
if args.verbose == 1:
logging.basicConfig(
level=logging.INFO, format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
elif args.verbose == 2:
logging.basicConfig(level=logging.DEBUG,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
else:
logging.basicConfig(
level=logging.WARN, format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
logging.warning("Skip DEBUG/INFO messages")
# check CUDA_VISIBLE_DEVICES
if args.ngpu > 0:
cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
if cvd is None:
logging.warning("CUDA_VISIBLE_DEVICES is not set.")
elif args.ngpu != len(cvd.split(",")):
logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.")
sys.exit(1)
# TODO(mn5k): support of multiple GPUs
if args.ngpu > 1:
logging.error("The program only supports ngpu=1.")
sys.exit(1)
# display PYTHONPATH
logging.info('python path = ' + os.environ.get('PYTHONPATH', '(None)'))
# seed setting
random.seed(args.seed)
np.random.seed(args.seed)
logging.info('set random seed = %d' % args.seed)
# validate rnn options
if args.rnnlm is not None and args.word_rnnlm is not None:
logging.error("It seems that both --rnnlm and --word-rnnlm are specified. Please use either option.")
sys.exit(1)
# trans
logging.info('backend = ' + args.backend)
if args.backend == "pytorch":
# Experimental API that supports custom LMs
from tools.espnet_minimal import trans
if args.dtype != "float32":
raise NotImplementedError(f"`--dtype {args.dtype}` is only available with `--api v2`")
trans(args)
else:
raise ValueError("Only pytorch is supported.")
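For orientation, the translation entry point can also be driven programmatically by passing an argument list equivalent to the command line. A minimal sketch, assuming the package is importable as tools.espnet_minimal; all file paths below are placeholders, not files shipped with the toolkit.
# Hedged sketch: invoke the speech-translation entry point with a CLI-equivalent
# argument list. Paths and the model snapshot are placeholders (assumptions).
from tools.espnet_minimal.bin.st_trans import main  # assumed import path

main([
    '--backend', 'pytorch',
    '--ngpu', '0',
    '--beam-size', '5',
    '--trans-json', 'dump/trans/data.json',        # hypothetical input json
    '--result-label', 'exp/results/data.json',     # hypothetical output json
    '--model', 'exp/st_model/model.acc.best',      # hypothetical model snapshot
])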
tts_decode
¶
TTS decoding script.
get_parser()
¶
Get parser of decoding arguments.
Source code in adviser/tools/espnet_minimal/bin/tts_decode.py
def get_parser():
"""Get parser of decoding arguments."""
parser = configargparse.ArgumentParser(
description='Synthesize speech from text using a TTS model on one CPU',
config_file_parser_class=configargparse.YAMLConfigFileParser,
formatter_class=configargparse.ArgumentDefaultsHelpFormatter)
# general configuration
parser.add('--config', is_config_file=True, help='config file path')
parser.add('--config2', is_config_file=True,
help='second config file path that overwrites the settings in `--config`.')
parser.add('--config3', is_config_file=True,
help='third config file path that overwrites the settings in `--config` and `--config2`.')
parser.add_argument('--ngpu', default=0, type=int,
help='Number of GPUs')
parser.add_argument('--backend', default='pytorch', type=str,
choices=['chainer', 'pytorch'],
help='Backend library')
parser.add_argument('--debugmode', default=1, type=int,
help='Debugmode')
parser.add_argument('--seed', default=1, type=int,
help='Random seed')
parser.add_argument('--out', type=str, required=True,
help='Output filename')
parser.add_argument('--verbose', '-V', default=0, type=int,
help='Verbose option')
parser.add_argument('--preprocess-conf', type=str, default=None,
help='The configuration file for the pre-processing')
# task related
parser.add_argument('--json', type=str, required=True,
help='Filename of train label data (json)')
parser.add_argument('--model', type=str, required=True,
help='Model file parameters to read')
parser.add_argument('--model-conf', type=str, default=None,
help='Model config file')
# decoding related
parser.add_argument('--maxlenratio', type=float, default=5,
help='Maximum length ratio in decoding')
parser.add_argument('--minlenratio', type=float, default=0,
help='Minimum length ratio in decoding')
parser.add_argument('--threshold', type=float, default=0.5,
help='Threshold value in decoding')
parser.add_argument('--use-att-constraint', type=strtobool, default=False,
help='Whether to use the attention constraint')
parser.add_argument('--backward-window', type=int, default=1,
help='Backward window size in the attention constraint')
parser.add_argument('--forward-window', type=int, default=3,
help='Forward window size in the attention constraint')
# save related
parser.add_argument('--save-durations', default=False, type=strtobool,
help='Whether to save durations converted from attentions')
parser.add_argument('--save-focus-rates', default=False, type=strtobool,
help='Whether to save focus rates of attentions')
return parser
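Because the parser is built with configargparse and a YAML config file parser, decoding options can also come from a file passed via --config. A small sketch under the assumption that a decode.yaml with matching option names exists; all file names and values below are placeholders.
# decode.yaml (hypothetical contents):
#   maxlenratio: 10.0
#   threshold: 0.5
#   use-att-constraint: true
parser = get_parser()
args = parser.parse_args([
    '--config', 'decode.yaml',                    # YAML keys map to the long option names
    '--json', 'dump/eval/data.json',              # hypothetical label json (required)
    '--model', 'exp/tts_model/model.loss.best',   # hypothetical model snapshot (required)
    '--out', 'exp/outputs/feats',                 # output filename (required)
])
print(args.maxlenratio, args.threshold, args.use_att_constraint)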
main(args)
¶
Run decoding.
Source code in adviser/tools/espnet_minimal/bin/tts_decode.py
def main(args):
"""Run deocding."""
parser = get_parser()
args = parser.parse_args(args)
# logging info
if args.verbose > 0:
logging.basicConfig(
level=logging.INFO, format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s')
else:
logging.basicConfig(
level=logging.WARN, format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s')
logging.warning('Skip DEBUG/INFO messages')
# check CUDA_VISIBLE_DEVICES
if args.ngpu > 0:
# python 2 case
if platform.python_version_tuple()[0] == '2':
if "clsp.jhu.edu" in subprocess.check_output(["hostname", "-f"]):
cvd = subprocess.check_output(["/usr/local/bin/free-gpu", "-n", str(args.ngpu)]).strip()
logging.info('CLSP: use gpu' + cvd)
os.environ['CUDA_VISIBLE_DEVICES'] = cvd
# python 3 case
else:
if "clsp.jhu.edu" in subprocess.check_output(["hostname", "-f"]).decode():
cvd = subprocess.check_output(["/usr/local/bin/free-gpu", "-n", str(args.ngpu)]).decode().strip()
logging.info('CLSP: use gpu' + cvd)
os.environ['CUDA_VISIBLE_DEVICES'] = cvd
cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
if cvd is None:
logging.warning("CUDA_VISIBLE_DEVICES is not set.")
elif args.ngpu != len(cvd.split(",")):
logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.")
sys.exit(1)
# display PYTHONPATH
logging.info('python path = ' + os.environ.get('PYTHONPATH', '(None)'))
# extract
logging.info('backend = ' + args.backend)
if args.backend == "pytorch":
from tools.espnet_minimal.tts.pytorch_backend.tts import decode
decode(args)
else:
raise NotImplementedError("Only pytorch is supported.")
nets
special
¶
asr_interface
¶
ASR Interface module.
predefined_asr
¶
ASRInterface
¶
ASR Interface for ESPnet model implementation.
Source code in adviser/tools/espnet_minimal/nets/asr_interface.py
class ASRInterface:
"""ASR Interface for ESPnet model implementation."""
@staticmethod
def add_arguments(parser):
"""Add arguments to parser."""
return parser
@classmethod
def build(cls, idim: int, odim: int, **kwargs):
"""Initialize this class with python-level args.
Args:
idim (int): The number of an input feature dim.
odim (int): The number of output vocab.
Returns:
ASRInterface: A new instance of ASRInterface.
"""
def wrap(parser):
return get_parser(parser, required=False)
args = argparse.Namespace(**kwargs)
args = fill_missing_args(args, wrap)
args = fill_missing_args(args, cls.add_arguments)
return cls(idim, odim, args)
def forward(self, xs, ilens, ys):
"""Compute loss for training.
:param xs:
For pytorch, batch of padded source sequences torch.Tensor (B, Tmax, idim)
For chainer, list of source sequences chainer.Variable
:param ilens: batch of lengths of source sequences (B)
For pytorch, torch.Tensor
For chainer, list of int
:param ys:
For pytorch, batch of padded source sequences torch.Tensor (B, Lmax)
For chainer, list of source sequences chainer.Variable
:return: loss value
:rtype: torch.Tensor for pytorch, chainer.Variable for chainer
"""
raise NotImplementedError("forward method is not implemented")
def recognize(self, x, recog_args, char_list=None, rnnlm=None):
"""Recognize x for evaluation.
:param ndarray x: input acoustic feature (B, T, D) or (T, D)
:param namespace recog_args: argument namespace containing options
:param list char_list: list of characters
:param torch.nn.Module rnnlm: language model module
:return: N-best decoding results
:rtype: list
"""
raise NotImplementedError("recognize method is not implemented")
def recognize_batch(self, x, recog_args, char_list=None, rnnlm=None):
"""Beam search implementation for batch.
:param torch.Tensor x: encoder hidden state sequences (B, Tmax, Henc)
:param namespace recog_args: argument namespace containing options
:param list char_list: list of characters
:param torch.nn.Module rnnlm: language model module
:return: N-best decoding results
:rtype: list
"""
raise NotImplementedError("Batch decoding is not supported yet.")
def calculate_all_attentions(self, xs, ilens, ys):
"""Caluculate attention.
:param list xs_pad: list of padded input sequences [(T1, idim), (T2, idim), ...]
:param ndarray ilens: batch of lengths of input sequences (B)
:param list ys: list of character id sequence tensor [(L1), (L2), (L3), ...]
:return: attention weights (B, Lmax, Tmax)
:rtype: float ndarray
"""
raise NotImplementedError("calculate_all_attentions method is not implemented")
@property
def attention_plot_class(self):
"""Get attention plot class."""
from tools.espnet_minimal.asr.asr_utils import PlotAttentionReport
return PlotAttentionReport
def encode(self, feat):
"""Encode feature in `beam_search` (optional).
Args:
x (numpy.ndarray): input feature (T, D)
Returns:
torch.Tensor for pytorch, chainer.Variable for chainer:
encoded feature (T, D)
"""
raise NotImplementedError("encode method is not implemented")
def scorers(self):
"""Get scorers for `beam_search` (optional).
Returns:
dict[str, ScorerInterface]: dict of `ScorerInterface` objects
"""
raise NotImplementedError("decoders method is not implemented")
attention_plot_class
property
readonly
¶
Get attention plot class.
add_arguments(parser)
staticmethod
¶
build(idim, odim, **kwargs)
classmethod
¶
Initialize this class with python-level args.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
idim | int | The number of an input feature dim. | required |
odim | int | The number of output vocab. | required |
Returns:
Type | Description |
---|---|
ASRInterface | A new instance of ASRInterface. |
Source code in adviser/tools/espnet_minimal/nets/asr_interface.py
@classmethod
def build(cls, idim: int, odim: int, **kwargs):
"""Initialize this class with python-level args.
Args:
idim (int): The number of an input feature dim.
odim (int): The number of output vocab.
Returns:
ASRInterface: A new instance of ASRInterface.
"""
def wrap(parser):
return get_parser(parser, required=False)
args = argparse.Namespace(**kwargs)
args = fill_missing_args(args, wrap)
args = fill_missing_args(args, cls.add_arguments)
return cls(idim, odim, args)
calculate_all_attentions(self, xs, ilens, ys)
¶
Calculate attention.
:param list xs_pad: list of padded input sequences [(T1, idim), (T2, idim), ...]
:param ndarray ilens: batch of lengths of input sequences (B)
:param list ys: list of character id sequence tensor [(L1), (L2), (L3), ...]
:return: attention weights (B, Lmax, Tmax)
:rtype: float ndarray
Source code in adviser/tools/espnet_minimal/nets/asr_interface.py
def calculate_all_attentions(self, xs, ilens, ys):
"""Caluculate attention.
:param list xs_pad: list of padded input sequences [(T1, idim), (T2, idim), ...]
:param ndarray ilens: batch of lengths of input sequences (B)
:param list ys: list of character id sequence tensor [(L1), (L2), (L3), ...]
:return: attention weights (B, Lmax, Tmax)
:rtype: float ndarray
"""
raise NotImplementedError("calculate_all_attentions method is not implemented")
encode(self, feat)
¶
Encode feature in beam_search (optional).
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | numpy.ndarray | input feature (T, D) | required |
Returns:
Type | Description |
---|---|
torch.Tensor for pytorch, chainer.Variable for chainer | encoded feature (T, D) |
Source code in adviser/tools/espnet_minimal/nets/asr_interface.py
def encode(self, feat):
"""Encode feature in `beam_search` (optional).
Args:
x (numpy.ndarray): input feature (T, D)
Returns:
torch.Tensor for pytorch, chainer.Variable for chainer:
encoded feature (T, D)
"""
raise NotImplementedError("encode method is not implemented")
forward(self, xs, ilens, ys)
¶
Compute loss for training.
:param xs: For pytorch, batch of padded source sequences torch.Tensor (B, Tmax, idim); for chainer, list of source sequences chainer.Variable
:param ilens: batch of lengths of source sequences (B). For pytorch, torch.Tensor; for chainer, list of int
:param ys: For pytorch, batch of padded source sequences torch.Tensor (B, Lmax); for chainer, list of source sequences chainer.Variable
:return: loss value
:rtype: torch.Tensor for pytorch, chainer.Variable for chainer
Source code in adviser/tools/espnet_minimal/nets/asr_interface.py
def forward(self, xs, ilens, ys):
"""Compute loss for training.
:param xs:
For pytorch, batch of padded source sequences torch.Tensor (B, Tmax, idim)
For chainer, list of source sequences chainer.Variable
:param ilens: batch of lengths of source sequences (B)
For pytorch, torch.Tensor
For chainer, list of int
:param ys:
For pytorch, batch of padded source sequences torch.Tensor (B, Lmax)
For chainer, list of source sequences chainer.Variable
:return: loss value
:rtype: torch.Tensor for pytorch, chainer.Variable for chainer
"""
raise NotImplementedError("forward method is not implemented")
recognize(self, x, recog_args, char_list=None, rnnlm=None)
¶
Recognize x for evaluation.
:param ndarray x: input acoustic feature (B, T, D) or (T, D)
:param namespace recog_args: argument namespace containing options
:param list char_list: list of characters
:param torch.nn.Module rnnlm: language model module
:return: N-best decoding results
:rtype: list
Source code in adviser/tools/espnet_minimal/nets/asr_interface.py
def recognize(self, x, recog_args, char_list=None, rnnlm=None):
"""Recognize x for evaluation.
:param ndarray x: input acoustic feature (B, T, D) or (T, D)
:param namespace recog_args: argument namespace containing options
:param list char_list: list of characters
:param torch.nn.Module rnnlm: language model module
:return: N-best decoding results
:rtype: list
"""
raise NotImplementedError("recognize method is not implemented")
recognize_batch(self, x, recog_args, char_list=None, rnnlm=None)
¶
Beam search implementation for batch.
:param torch.Tensor x: encoder hidden state sequences (B, Tmax, Henc)
:param namespace recog_args: argument namespace containing options
:param list char_list: list of characters
:param torch.nn.Module rnnlm: language model module
:return: N-best decoding results
:rtype: list
Source code in adviser/tools/espnet_minimal/nets/asr_interface.py
def recognize_batch(self, x, recog_args, char_list=None, rnnlm=None):
"""Beam search implementation for batch.
:param torch.Tensor x: encoder hidden state sequences (B, Tmax, Henc)
:param namespace recog_args: argument namespace containing options
:param list char_list: list of characters
:param torch.nn.Module rnnlm: language model module
:return: N-best decoding results
:rtype: list
"""
raise NotImplementedError("Batch decoding is not supported yet.")
scorers(self)
¶
Get scorers for beam_search (optional).
Returns:
Type | Description |
---|---|
dict[str, ScorerInterface] | dict of ScorerInterface objects |
dynamic_import_asr(module, backend)
¶
Import ASR models dynamically.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
module | str | module_name:class_name or alias in predefined_asr | required |
backend | str | NN backend. e.g., pytorch, chainer | required |
Returns:
Type | Description |
---|---|
type | ASR class |
Source code in adviser/tools/espnet_minimal/nets/asr_interface.py
def dynamic_import_asr(module, backend):
"""Import ASR models dynamically.
Args:
module (str): module_name:class_name or alias in `predefined_asr`
backend (str): NN backend. e.g., pytorch, chainer
Returns:
type: ASR class
"""
model_class = dynamic_import(module, predefined_asr.get(backend, dict()))
assert issubclass(model_class, ASRInterface), f"{module} does not implement ASRInterface"
return model_class
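Usage is a thin wrapper around dynamic_import: pass either an alias registered in predefined_asr for the chosen backend, or an explicit module_name:class_name string. A hedged sketch; the module path below is a placeholder, not an entry guaranteed to exist in this minimal package.
# Resolve an ASR class by an explicit "module:Class" path (placeholder), then
# build an instance with python-level keyword arguments via ASRInterface.build.
ModelClass = dynamic_import_asr('my_package.my_asr:MyASRModel', 'pytorch')  # hypothetical path
model = ModelClass.build(idim=83, odim=5002)  # example feature dim / vocab size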
batch_beam_search
¶
Parallel beam search module.
BatchBeamSearch (BeamSearch)
¶
Batch beam search implementation.
Source code in adviser/tools/espnet_minimal/nets/batch_beam_search.py
class BatchBeamSearch(BeamSearch):
"""Batch beam search implementation."""
def batchfy(self, hyps: List[Hypothesis]) -> BatchHypothesis:
"""Convert list to batch."""
if len(hyps) == 0:
return BatchHypothesis()
return BatchHypothesis(
yseq=pad_sequence([h.yseq for h in hyps], batch_first=True, padding_value=self.eos),
length=torch.tensor([len(h.yseq) for h in hyps], dtype=torch.int64),
score=torch.tensor([h.score for h in hyps]),
scores={k: torch.tensor([h.scores[k] for h in hyps]) for k in self.scorers},
states={k: [h.states[k] for h in hyps] for k in self.scorers}
)
def _batch_select(self, hyps: BatchHypothesis, ids: List[int]) -> BatchHypothesis:
return BatchHypothesis(
yseq=hyps.yseq[ids],
score=hyps.score[ids],
length=hyps.length[ids],
scores={k: v[ids] for k, v in hyps.scores.items()},
states={k: [self.scorers[k].select_state(v, i) for i in ids]
for k, v in hyps.states.items()},
)
def _select(self, hyps: BatchHypothesis, i: int) -> Hypothesis:
return Hypothesis(
yseq=hyps.yseq[i, :hyps.length[i]],
score=hyps.score[i],
scores={k: v[i] for k, v in hyps.scores.items()},
states={k: self.scorers[k].select_state(v, i) for k, v in hyps.states.items()},
)
def unbatchfy(self, batch_hyps: BatchHypothesis) -> List[Hypothesis]:
"""Revert batch to list."""
return [
Hypothesis(
yseq=batch_hyps.yseq[i][:batch_hyps.length[i]],
score=batch_hyps.score[i],
scores={k: batch_hyps.scores[k][i] for k in self.scorers},
states={k: v.select_state(
batch_hyps.states[k], i) for k, v in self.scorers.items()}
) for i in range(len(batch_hyps.length))]
def batch_beam(self, weighted_scores: torch.Tensor, ids: torch.Tensor) \
-> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Batch-compute topk full token ids and partial token ids.
Args:
weighted_scores (torch.Tensor): The weighted sum scores for each tokens.
Its shape is `(n_beam, self.vocab_size)`.
ids (torch.Tensor): The partial token ids to compute topk.
Its shape is `(n_beam, self.pre_beam_size)`.
Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
The topk full (prev_hyp, new_token) ids and partial (prev_hyp, new_token) ids.
Their shapes are all `(self.beam_size,)`
"""
if not self.do_pre_beam:
top_ids = weighted_scores.view(-1).topk(self.beam_size)[1]
# Because of the flatten above, `top_ids` is organized as:
# [hyp1 * V + token1, hyp2 * V + token2, ..., hypK * V + tokenK],
# where V is `self.n_vocab` and K is `self.beam_size`
prev_hyp_ids = top_ids // self.n_vocab
new_token_ids = top_ids % self.n_vocab
return prev_hyp_ids, new_token_ids, prev_hyp_ids, new_token_ids
raise NotImplementedError("batch decoding with PartialScorer is not supported yet.")
def init_hyp(self, x: torch.Tensor) -> BatchHypothesis:
"""Get an initial hypothesis data.
Args:
x (torch.Tensor): The encoder output feature
Returns:
Hypothesis: The initial hypothesis.
"""
init_states = dict()
init_scores = dict()
for k, d in self.scorers.items():
init_states[k] = d.init_state(x)
init_scores[k] = 0.0
return self.batchfy(super().init_hyp(x))
def search(self, running_hyps: BatchHypothesis, x: torch.Tensor) -> BatchHypothesis:
"""Search new tokens for running hypotheses and encoded speech x.
Args:
running_hyps (BatchHypothesis): Running hypotheses on beam
x (torch.Tensor): Encoded speech feature (T, D)
Returns:
BatchHypothesis: Best sorted hypotheses
"""
n_batch = len(running_hyps)
# batch scoring
scores, states = self.score_full(running_hyps, x.expand(n_batch, *x.shape))
if self.do_pre_beam:
part_ids = torch.topk(scores[self.pre_beam_score_key], self.pre_beam_size, dim=-1)[1]
else:
part_ids = torch.arange(self.n_vocab, device=x.device).expand(n_batch, self.n_vocab)
part_scores, part_states = self.score_partial(running_hyps, part_ids, x)
# weighted sum scores
weighted_scores = torch.zeros(n_batch, self.n_vocab, dtype=x.dtype, device=x.device)
for k in self.full_scorers:
weighted_scores += self.weights[k] * scores[k]
for k in self.part_scorers:
weighted_scores[part_ids] += self.weights[k] * part_scores[k]
weighted_scores += running_hyps.score.to(x.device).unsqueeze(1)
# TODO(karita): do not use list. use batch instead
# update hyps
best_hyps = []
prev_hyps = self.unbatchfy(running_hyps)
for full_prev_hyp_id, full_new_token_id, part_prev_hyp_id, part_new_token_id in zip(
*self.batch_beam(weighted_scores, part_ids)):
prev_hyp = prev_hyps[full_prev_hyp_id]
best_hyps.append(Hypothesis(
score=weighted_scores[full_prev_hyp_id, full_new_token_id],
yseq=self.append_token(prev_hyp.yseq, full_new_token_id),
scores=self.merge_scores(
prev_hyp.scores,
{k: v[full_prev_hyp_id] for k, v in scores.items()}, full_new_token_id,
{k: v[part_prev_hyp_id] for k, v in part_scores.items()}, part_new_token_id),
states=self.merge_states(
{k: self.full_scorers[k].select_state(v, full_prev_hyp_id) for k, v in states.items()},
{k: self.part_scorers[k].select_state(v, part_prev_hyp_id) for k, v in part_states.items()},
part_new_token_id)
))
return self.batchfy(best_hyps)
def post_process(self, i: int, maxlen: int, maxlenratio: float,
running_hyps: BatchHypothesis, ended_hyps: List[Hypothesis]) -> BatchHypothesis:
"""Perform post-processing of beam search iterations.
Args:
i (int): The length of hypothesis tokens.
maxlen (int): The maximum length of tokens in beam search.
maxlenratio (int): The maximum length ratio in beam search.
running_hyps (BatchHypothesis): The running hypotheses in beam search.
ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.
Returns:
BatchHypothesis: The new running hypotheses.
"""
n_batch, maxlen = running_hyps.yseq.shape
logging.debug(f'the number of running hypotheses: {n_batch}')
if self.token_list is not None:
logging.debug("best hypo: " + "".join(
[self.token_list[x] for x in running_hyps.yseq[0, 1:running_hyps.length[0]]]))
# add eos in the final loop to avoid that there are no ended hyps
if i == maxlen - 1:
logging.info("adding <eos> in the last position in the loop")
running_hyps.yseq.resize_(n_batch, maxlen + 1)
running_hyps.yseq[:, -1] = self.eos
running_hyps.yseq.index_fill_(1, running_hyps.length, self.eos)
# add ended hypotheses to a final list, and remove them from the current hypotheses
# (this will be a problem if the number of hyps < beam)
is_eos = running_hyps.yseq[torch.arange(n_batch), running_hyps.length - 1] == self.eos
for b in torch.nonzero(is_eos).view(-1):
hyp = self._select(running_hyps, b)
ended_hyps.append(hyp)
remained_ids = torch.nonzero(is_eos == 0).view(-1)
return self._batch_select(running_hyps, remained_ids)
batch_beam(self, weighted_scores, ids)
¶
Batch-compute topk full token ids and partial token ids.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
weighted_scores | torch.Tensor | The weighted sum scores for each token. Its shape is (n_beam, self.vocab_size). | required |
ids | torch.Tensor | The partial token ids to compute topk. Its shape is (n_beam, self.pre_beam_size). | required |
Returns:
Type | Description |
---|---|
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] | The topk full (prev_hyp, new_token) ids and partial (prev_hyp, new_token) ids. Their shapes are all (self.beam_size,) |
Source code in adviser/tools/espnet_minimal/nets/batch_beam_search.py
def batch_beam(self, weighted_scores: torch.Tensor, ids: torch.Tensor) \
-> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Batch-compute topk full token ids and partial token ids.
Args:
weighted_scores (torch.Tensor): The weighted sum scores for each tokens.
Its shape is `(n_beam, self.vocab_size)`.
ids (torch.Tensor): The partial token ids to compute topk.
Its shape is `(n_beam, self.pre_beam_size)`.
Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
The topk full (prev_hyp, new_token) ids and partial (prev_hyp, new_token) ids.
Their shapes are all `(self.beam_size,)`
"""
if not self.do_pre_beam:
top_ids = weighted_scores.view(-1).topk(self.beam_size)[1]
# Because of the flatten above, `top_ids` is organized as:
# [hyp1 * V + token1, hyp2 * V + token2, ..., hypK * V + tokenK],
# where V is `self.n_vocab` and K is `self.beam_size`
prev_hyp_ids = top_ids // self.n_vocab
new_token_ids = top_ids % self.n_vocab
return prev_hyp_ids, new_token_ids, prev_hyp_ids, new_token_ids
raise NotImplementedError("batch decoding with PartialScorer is not supported yet.")
batchfy(self, hyps)
¶
Convert list to batch.
Source code in adviser/tools/espnet_minimal/nets/batch_beam_search.py
def batchfy(self, hyps: List[Hypothesis]) -> BatchHypothesis:
"""Convert list to batch."""
if len(hyps) == 0:
return BatchHypothesis()
return BatchHypothesis(
yseq=pad_sequence([h.yseq for h in hyps], batch_first=True, padding_value=self.eos),
length=torch.tensor([len(h.yseq) for h in hyps], dtype=torch.int64),
score=torch.tensor([h.score for h in hyps]),
scores={k: torch.tensor([h.scores[k] for h in hyps]) for k in self.scorers},
states={k: [h.states[k] for h in hyps] for k in self.scorers}
)
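The key step is pad_sequence, which right-pads variable-length token sequences with the eos id so they stack into one (batch, maxlen) tensor; a standalone sketch with toy ids (eos assumed to be 2 here).
import torch
from torch.nn.utils.rnn import pad_sequence

eos = 2  # assumed end-of-sequence id for this toy example
yseqs = [torch.tensor([0, 5, 7]), torch.tensor([0, 3])]   # two running hypotheses
batch = pad_sequence(yseqs, batch_first=True, padding_value=eos)
lengths = torch.tensor([len(y) for y in yseqs], dtype=torch.int64)
print(batch)    # tensor([[0, 5, 7], [0, 3, 2]])
print(lengths)  # tensor([3, 2])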
init_hyp(self, x)
¶
Get an initial hypothesis data.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | torch.Tensor | The encoder output feature | required |
Returns:
Type | Description |
---|---|
Hypothesis | The initial hypothesis. |
Source code in adviser/tools/espnet_minimal/nets/batch_beam_search.py
def init_hyp(self, x: torch.Tensor) -> BatchHypothesis:
"""Get an initial hypothesis data.
Args:
x (torch.Tensor): The encoder output feature
Returns:
Hypothesis: The initial hypothesis.
"""
init_states = dict()
init_scores = dict()
for k, d in self.scorers.items():
init_states[k] = d.init_state(x)
init_scores[k] = 0.0
return self.batchfy(super().init_hyp(x))
post_process(self, i, maxlen, maxlenratio, running_hyps, ended_hyps)
¶
Perform post-processing of beam search iterations.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
i | int | The length of hypothesis tokens. | required |
maxlen | int | The maximum length of tokens in beam search. | required |
maxlenratio | int | The maximum length ratio in beam search. | required |
running_hyps | BatchHypothesis | The running hypotheses in beam search. | required |
ended_hyps | List[Hypothesis] | The ended hypotheses in beam search. | required |
Returns:
Type | Description |
---|---|
BatchHypothesis | The new running hypotheses. |
Source code in adviser/tools/espnet_minimal/nets/batch_beam_search.py
def post_process(self, i: int, maxlen: int, maxlenratio: float,
running_hyps: BatchHypothesis, ended_hyps: List[Hypothesis]) -> BatchHypothesis:
"""Perform post-processing of beam search iterations.
Args:
i (int): The length of hypothesis tokens.
maxlen (int): The maximum length of tokens in beam search.
maxlenratio (int): The maximum length ratio in beam search.
running_hyps (BatchHypothesis): The running hypotheses in beam search.
ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.
Returns:
BatchHypothesis: The new running hypotheses.
"""
n_batch, maxlen = running_hyps.yseq.shape
logging.debug(f'the number of running hypotheses: {n_batch}')
if self.token_list is not None:
logging.debug("best hypo: " + "".join(
[self.token_list[x] for x in running_hyps.yseq[0, 1:running_hyps.length[0]]]))
# add eos in the final loop to avoid that there are no ended hyps
if i == maxlen - 1:
logging.info("adding <eos> in the last position in the loop")
running_hyps.yseq.resize_(n_batch, maxlen + 1)
running_hyps.yseq[:, -1] = self.eos
running_hyps.yseq.index_fill_(1, running_hyps.length, self.eos)
# add ended hypotheses to a final list, and remove them from the current hypotheses
# (this will be a problem if the number of hyps < beam)
is_eos = running_hyps.yseq[torch.arange(n_batch), running_hyps.length - 1] == self.eos
for b in torch.nonzero(is_eos).view(-1):
hyp = self._select(running_hyps, b)
ended_hyps.append(hyp)
remained_ids = torch.nonzero(is_eos == 0).view(-1)
return self._batch_select(running_hyps, remained_ids)
search(self, running_hyps, x)
¶
Search new tokens for running hypotheses and encoded speech x.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
running_hyps | BatchHypothesis | Running hypotheses on beam | required |
x | torch.Tensor | Encoded speech feature (T, D) | required |
Returns:
Type | Description |
---|---|
BatchHypothesis | Best sorted hypotheses |
Source code in adviser/tools/espnet_minimal/nets/batch_beam_search.py
def search(self, running_hyps: BatchHypothesis, x: torch.Tensor) -> BatchHypothesis:
"""Search new tokens for running hypotheses and encoded speech x.
Args:
running_hyps (BatchHypothesis): Running hypotheses on beam
x (torch.Tensor): Encoded speech feature (T, D)
Returns:
BatchHypothesis: Best sorted hypotheses
"""
n_batch = len(running_hyps)
# batch scoring
scores, states = self.score_full(running_hyps, x.expand(n_batch, *x.shape))
if self.do_pre_beam:
part_ids = torch.topk(scores[self.pre_beam_score_key], self.pre_beam_size, dim=-1)[1]
else:
part_ids = torch.arange(self.n_vocab, device=x.device).expand(n_batch, self.n_vocab)
part_scores, part_states = self.score_partial(running_hyps, part_ids, x)
# weighted sum scores
weighted_scores = torch.zeros(n_batch, self.n_vocab, dtype=x.dtype, device=x.device)
for k in self.full_scorers:
weighted_scores += self.weights[k] * scores[k]
for k in self.part_scorers:
weighted_scores[part_ids] += self.weights[k] * part_scores[k]
weighted_scores += running_hyps.score.to(x.device).unsqueeze(1)
# TODO(karita): do not use list. use batch instead
# update hyps
best_hyps = []
prev_hyps = self.unbatchfy(running_hyps)
for full_prev_hyp_id, full_new_token_id, part_prev_hyp_id, part_new_token_id in zip(
*self.batch_beam(weighted_scores, part_ids)):
prev_hyp = prev_hyps[full_prev_hyp_id]
best_hyps.append(Hypothesis(
score=weighted_scores[full_prev_hyp_id, full_new_token_id],
yseq=self.append_token(prev_hyp.yseq, full_new_token_id),
scores=self.merge_scores(
prev_hyp.scores,
{k: v[full_prev_hyp_id] for k, v in scores.items()}, full_new_token_id,
{k: v[part_prev_hyp_id] for k, v in part_scores.items()}, part_new_token_id),
states=self.merge_states(
{k: self.full_scorers[k].select_state(v, full_prev_hyp_id) for k, v in states.items()},
{k: self.part_scorers[k].select_state(v, part_prev_hyp_id) for k, v in part_states.items()},
part_new_token_id)
))
return self.batchfy(best_hyps)
unbatchfy(self, batch_hyps)
¶
Revert batch to list.
Source code in adviser/tools/espnet_minimal/nets/batch_beam_search.py
def unbatchfy(self, batch_hyps: BatchHypothesis) -> List[Hypothesis]:
"""Revert batch to list."""
return [
Hypothesis(
yseq=batch_hyps.yseq[i][:batch_hyps.length[i]],
score=batch_hyps.score[i],
scores={k: batch_hyps.scores[k][i] for k in self.scorers},
states={k: v.select_state(
batch_hyps.states[k], i) for k, v in self.scorers.items()}
) for i in range(len(batch_hyps.length))]
BatchHypothesis (tuple)
¶
Batchfied/Vectorized hypothesis data type.
Source code in adviser/tools/espnet_minimal/nets/batch_beam_search.py
class BatchHypothesis(NamedTuple):
"""Batchfied/Vectorized hypothesis data type."""
yseq: torch.Tensor = torch.tensor([]) # (batch, maxlen)
score: torch.Tensor = torch.tensor([]) # (batch,)
length: torch.Tensor = torch.tensor([]) # (batch,)
scores: Dict[str, torch.Tensor] = dict() # values: (batch,)
states: Dict[str, Dict] = dict()
def __len__(self) -> int:
"""Return a batch size."""
return len(self.length)
length: Tensor
¶
score: Tensor
¶
scores: Dict[str, torch.Tensor]
¶
states: Dict[str, Dict]
¶
yseq: Tensor
¶
__getnewargs__(self)
special
¶
__len__(self)
special
¶
__new__(_cls, yseq=tensor([]), score=tensor([]), length=tensor([]), scores={}, states={})
special
staticmethod
¶
Create new instance of BatchHypothesis(yseq, score, length, scores, states)
__repr__(self)
special
¶
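Because it is a NamedTuple, a BatchHypothesis can be built and inspected directly; a toy instance with two hypotheses and a single scorer key 'lm' (all values invented for illustration, assuming BatchHypothesis is importable from this module).
import torch

hyps = BatchHypothesis(
    yseq=torch.tensor([[0, 5, 7], [0, 3, 2]]),   # (batch, maxlen), padded with eos
    score=torch.tensor([-1.2, -2.5]),            # (batch,) accumulated scores
    length=torch.tensor([3, 2]),                 # true lengths before padding
    scores={'lm': torch.tensor([-0.4, -0.9])},   # per-scorer accumulated scores
    states={'lm': [None, None]},                 # per-scorer opaque states
)
print(len(hyps))  # 2, via __len__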
beam_search
¶
Beam search module.
BeamSearch (Module)
¶
Beam search implementation.
Source code in adviser/tools/espnet_minimal/nets/beam_search.py
class BeamSearch(torch.nn.Module):
"""Beam search implementation."""
def __init__(self, scorers: Dict[str, ScorerInterface], weights: Dict[str, float],
beam_size: int, vocab_size: int,
sos: int, eos: int, token_list: List[str] = None,
pre_beam_ratio: float = 1.5, pre_beam_score_key: str = None):
"""Initialize beam search.
Args:
scorers (dict[str, ScorerInterface]): Dict of decoder modules e.g., Decoder, CTCPrefixScorer, LM
The scorer will be ignored if it is `None`
weights (dict[str, float]): Dict of weights for each scorers
The scorer will be ignored if its weight is 0
beam_size (int): The number of hypotheses kept during search
vocab_size (int): The number of vocabulary
sos (int): Start of sequence id
eos (int): End of sequence id
token_list (list[str]): List of tokens for debug log
pre_beam_score_key (str): key of scores to perform pre-beam search
pre_beam_ratio (float): beam size in the pre-beam search will be `int(pre_beam_ratio * beam_size)`
"""
super().__init__()
# set scorers
self.weights = weights
self.scorers = dict()
self.full_scorers = dict()
self.part_scorers = dict()
# this module dict is required for recursive cast `self.to(device, dtype)` in `recog.py`
self.nn_dict = torch.nn.ModuleDict()
for k, v in scorers.items():
w = weights.get(k, 0)
if w == 0 or v is None:
continue
assert isinstance(v, ScorerInterface), f"{k} ({type(v)}) does not implement ScorerInterface"
self.scorers[k] = v
if isinstance(v, PartialScorerInterface):
self.part_scorers[k] = v
else:
self.full_scorers[k] = v
if isinstance(v, torch.nn.Module):
self.nn_dict[k] = v
# set configurations
self.sos = sos
self.eos = eos
self.token_list = token_list
self.pre_beam_size = int(pre_beam_ratio * beam_size)
self.beam_size = beam_size
self.n_vocab = vocab_size
if pre_beam_score_key is not None and pre_beam_score_key not in self.full_scorers:
raise KeyError(f"{pre_beam_score_key} is not found in {self.full_scorers}")
self.pre_beam_score_key = pre_beam_score_key
self.do_pre_beam = self.pre_beam_score_key is not None and \
self.pre_beam_size < self.n_vocab and len(self.part_scorers) > 0
def init_hyp(self, x: torch.Tensor) -> Hypothesis:
"""Get an initial hypothesis data.
Args:
x (torch.Tensor): The encoder output feature
Returns:
Hypothesis: The initial hypothesis.
"""
init_states = dict()
init_scores = dict()
for k, d in self.scorers.items():
init_states[k] = d.init_state(x)
init_scores[k] = 0.0
return [Hypothesis(
score=0.0, scores=init_scores, states=init_states,
yseq=torch.tensor([self.sos], device=x.device))]
@staticmethod
def append_token(xs: torch.Tensor, x: int) -> torch.Tensor:
"""Append new token to prefix tokens.
Args:
xs (torch.Tensor): The prefix token
x (int): The new token to append
Returns:
torch.Tensor: New tensor contains: xs + [x] with xs.dtype and xs.device
"""
x = torch.tensor([x], dtype=xs.dtype, device=xs.device)
return torch.cat((xs, x))
def score_full(self, hyp: Hypothesis, x: torch.Tensor) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
"""Score new hypothesis by `self.full_scorers`.
Args:
hyp (Hypothesis): Hypothesis with prefix tokens to score
x (torch.Tensor): Corresponding input feature
Returns:
Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
score dict of `hyp` that has string keys of `self.full_scorers`
and tensor score values of shape: `(self.n_vocab,)`,
and state dict that has string keys and state values of `self.full_scorers`
"""
scores = dict()
states = dict()
for k, d in self.full_scorers.items():
scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], x)
return scores, states
def score_partial(self, hyp: Hypothesis, ids: torch.Tensor, x: torch.Tensor) \
-> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
"""Score new hypothesis by `self.part_scorers`.
Args:
hyp (Hypothesis): Hypothesis with prefix tokens to score
ids (torch.Tensor): 1D tensor of new partial tokens to score
x (torch.Tensor): Corresponding input feature
Returns:
Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
score dict of `hyp` that has string keys of `self.part_scorers`
and tensor score values of shape: `(len(ids),)`,
and state dict that has string keys and state values of `self.part_scorers`
"""
scores = dict()
states = dict()
for k, d in self.part_scorers.items():
scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], x)
return scores, states
def beam(self, weighted_scores: torch.Tensor, ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute topk full token ids and partial token ids.
Args:
weighted_scores (torch.Tensor): The weighted sum scores for each tokens. Its shape is `(self.n_vocab,)`.
ids (torch.Tensor): The partial token ids to compute topk
Returns:
Tuple[torch.Tensor, torch.Tensor]: The topk full token ids and partial token ids.
Their shapes are `(self.beam_size,)`
"""
# no pre beam performed
if weighted_scores.size(0) == ids.size(0):
top_ids = weighted_scores.topk(self.beam_size)[1]
return top_ids, top_ids
# mask pruned in pre-beam not to select in topk
tmp = weighted_scores[ids]
weighted_scores[:] = -float("inf")
weighted_scores[ids] = tmp
top_ids = weighted_scores.topk(self.beam_size)[1]
local_ids = weighted_scores[ids].topk(self.beam_size)[1]
return top_ids, local_ids
@staticmethod
def merge_scores(prev_scores: Dict[str, float], next_full_scores: Dict[str, torch.Tensor], full_idx: int,
next_part_scores: Dict[str, torch.Tensor], part_idx: int) -> Dict[str, torch.Tensor]:
"""Merge scores for new hypothesis.
Args:
prev_scores (Dict[str, float]): The previous hypothesis scores by `self.scorers`
next_full_scores (Dict[str, torch.Tensor]): scores by `self.full_scorers`
full_idx (int): The next token id for `next_full_scores`
next_part_scores (Dict[str, torch.Tensor]): scores of partial tokens by `self.part_scorers`
part_idx (int): The new token id for `next_part_scores`
Returns:
Dict[str, torch.Tensor]: The new score dict.
Its keys are names of `self.full_scorers` and `self.part_scorers`.
Its values are scalar tensors by the scorers.
"""
new_scores = dict()
for k, v in next_full_scores.items():
new_scores[k] = prev_scores[k] + v[full_idx]
for k, v in next_part_scores.items():
new_scores[k] = v[part_idx]
return new_scores
def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
"""Merge states for new hypothesis.
Args:
states: states of `self.full_scorers`
part_states: states of `self.part_scorers`
part_idx (int): The new token id for `part_scores`
Returns:
Dict[str, torch.Tensor]: The new score dict.
Its keys are names of `self.full_scorers` and `self.part_scorers`.
Its values are states of the scorers.
"""
new_states = dict()
for k, v in states.items():
new_states[k] = v
for k, d in self.part_scorers.items():
new_states[k] = d.select_state(part_states[k], part_idx)
return new_states
def search(self, running_hyps: List[Hypothesis], x: torch.Tensor) -> List[Hypothesis]:
"""Search new tokens for running hypotheses and encoded speech x.
Args:
running_hyps (List[Hypothesis]): Running hypotheses on beam
x (torch.Tensor): Encoded speech feature (T, D)
Returns:
List[Hypotheses]: Best sorted hypotheses
"""
best_hyps = []
part_ids = torch.arange(self.n_vocab, device=x.device) # no pre-beam
for hyp in running_hyps:
# scoring
scores, states = self.score_full(hyp, x)
if self.do_pre_beam:
part_ids = torch.topk(scores[self.pre_beam_score_key], self.pre_beam_size)[1]
part_scores, part_states = self.score_partial(hyp, part_ids, x)
# weighted sum scores
weighted_scores = torch.zeros(
self.n_vocab, dtype=x.dtype, device=x.device)
for k in self.full_scorers:
weighted_scores += self.weights[k] * scores[k]
for k in self.part_scorers:
weighted_scores[part_ids] += self.weights[k] * part_scores[k]
weighted_scores += hyp.score
# update hyps
for j, part_j in zip(*self.beam(weighted_scores, part_ids)):
# will be (2 x beam at most)
best_hyps.append(Hypothesis(
score=weighted_scores[j],
yseq=self.append_token(hyp.yseq, j),
scores=self.merge_scores(
hyp.scores, scores, j, part_scores, part_j),
states=self.merge_states(states, part_states, part_j)))
# sort and prune 2 x beam -> beam
best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[:min(len(best_hyps), self.beam_size)]
return best_hyps
def forward(self, x: torch.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0) -> List[Hypothesis]:
"""Perform beam search.
Args:
x (torch.Tensor): Encoded speech feature (T, D)
maxlenratio (float): Input length ratio to obtain max output length.
If maxlenratio=0.0 (default), it uses an end-detect function
to automatically find maximum hypothesis lengths
minlenratio (float): Input length ratio to obtain min output length.
Returns:
list[Hypothesis]: N-best decoding results
"""
# set length bounds
if maxlenratio == 0:
maxlen = x.shape[0]
else:
maxlen = max(1, int(maxlenratio * x.size(0)))
minlen = int(minlenratio * x.size(0))
logging.info('max output length: ' + str(maxlen))
logging.info('min output length: ' + str(minlen))
# main loop of prefix search
running_hyps = self.init_hyp(x)
ended_hyps = []
for i in range(maxlen):
logging.debug('position ' + str(i))
best = self.search(running_hyps, x)
# post process of one iteration
running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
# end detection
if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i):
logging.info(f'end detected at {i}')
break
if len(running_hyps) == 0:
logging.info('no hypothesis. Finish decoding.')
break
else:
logging.debug(f'remaining hypotheses: {len(running_hyps)}')
nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
# check the number of hypotheses
if len(nbest_hyps) == 0:
logging.warning('there are no N-best results, performing recognition again with a smaller minlenratio.')
return [] if minlenratio < 0.1 else self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1))
# report the best result
best = nbest_hyps[0]
logging.info(f'total log probability: {best.score}')
logging.info(f'normalized log probability: {best.score / len(best.yseq)}')
return nbest_hyps
def post_process(self, i: int, maxlen: int, maxlenratio: float,
running_hyps: List[Hypothesis], ended_hyps: List[Hypothesis]) -> List[Hypothesis]:
"""Perform post-processing of beam search iterations.
Args:
i (int): The length of hypothesis tokens.
maxlen (int): The maximum length of tokens in beam search.
maxlenratio (int): The maximum length ratio in beam search.
running_hyps (List[Hypothesis]): The running hypotheses in beam search.
ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.
Returns:
List[Hypothesis]: The new running hypotheses.
"""
logging.debug(f'the number of running hypotheses: {len(running_hyps)}')
if self.token_list is not None:
logging.debug("best hypo: " + "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]]))
# add eos in the final loop to avoid that there are no ended hyps
if i == maxlen - 1:
logging.info("adding <eos> in the last position in the loop")
running_hyps = [h._replace(yseq=self.append_token(h.yseq, self.eos)) for h in running_hyps]
# add ended hypotheses to a final list, and remove them from the current hypotheses
# (this will be a problem if the number of hyps < beam)
remained_hyps = []
for hyp in running_hyps:
if hyp.yseq[-1] == self.eos:
# e.g., Word LM needs to add final <eos> score
for k, d in chain(self.full_scorers.items(), self.part_scorers.items()):
s = d.final_score(hyp.states[k])
hyp.scores[k] += s
hyp = hyp._replace(score=hyp.score + self.weights[k] * s)
ended_hyps.append(hyp)
else:
remained_hyps.append(hyp)
return remained_hyps
__init__(self, scorers, weights, beam_size, vocab_size, sos, eos, token_list=None, pre_beam_ratio=1.5, pre_beam_score_key=None)
special
¶
Initialize beam search.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
scorers | dict[str, ScorerInterface] | Dict of decoder modules e.g., Decoder, CTCPrefixScorer, LM. The scorer will be ignored if it is None | required |
weights | dict[str, float] | Dict of weights for each scorer. The scorer will be ignored if its weight is 0 | required |
beam_size | int | The number of hypotheses kept during search | required |
vocab_size | int | The number of vocabulary | required |
sos | int | Start of sequence id | required |
eos | int | End of sequence id | required |
token_list | list[str] | List of tokens for debug log | None |
pre_beam_score_key | str | key of scores to perform pre-beam search | None |
pre_beam_ratio | float | beam size in the pre-beam search will be int(pre_beam_ratio * beam_size) | 1.5 |
Source code in adviser/tools/espnet_minimal/nets/beam_search.py
def __init__(self, scorers: Dict[str, ScorerInterface], weights: Dict[str, float],
beam_size: int, vocab_size: int,
sos: int, eos: int, token_list: List[str] = None,
pre_beam_ratio: float = 1.5, pre_beam_score_key: str = None):
"""Initialize beam search.
Args:
scorers (dict[str, ScorerInterface]): Dict of decoder modules e.g., Decoder, CTCPrefixScorer, LM
The scorer will be ignored if it is `None`
weights (dict[str, float]): Dict of weights for each scorers
The scorer will be ignored if its weight is 0
beam_size (int): The number of hypotheses kept during search
vocab_size (int): The number of vocabulary
sos (int): Start of sequence id
eos (int): End of sequence id
token_list (list[str]): List of tokens for debug log
pre_beam_score_key (str): key of scores to perform pre-beam search
pre_beam_ratio (float): beam size in the pre-beam search will be `int(pre_beam_ratio * beam_size)`
"""
super().__init__()
# set scorers
self.weights = weights
self.scorers = dict()
self.full_scorers = dict()
self.part_scorers = dict()
# this module dict is required for recursive cast `self.to(device, dtype)` in `recog.py`
self.nn_dict = torch.nn.ModuleDict()
for k, v in scorers.items():
w = weights.get(k, 0)
if w == 0 or v is None:
continue
assert isinstance(v, ScorerInterface), f"{k} ({type(v)}) does not implement ScorerInterface"
self.scorers[k] = v
if isinstance(v, PartialScorerInterface):
self.part_scorers[k] = v
else:
self.full_scorers[k] = v
if isinstance(v, torch.nn.Module):
self.nn_dict[k] = v
# set configurations
self.sos = sos
self.eos = eos
self.token_list = token_list
self.pre_beam_size = int(pre_beam_ratio * beam_size)
self.beam_size = beam_size
self.n_vocab = vocab_size
if pre_beam_score_key is not None and pre_beam_score_key not in self.full_scorers:
raise KeyError(f"{pre_beam_score_key} is not found in {self.full_scorers}")
self.pre_beam_score_key = pre_beam_score_key
self.do_pre_beam = self.pre_beam_score_key is not None and \
self.pre_beam_size < self.n_vocab and len(self.part_scorers) > 0
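A hedged construction sketch: the scorer below is invented for illustration and simply assigns every token the same log-probability; the ScorerInterface import path is assumed from the package layout, and the base class is assumed to provide the default init_state/final_score implementations it has in upstream ESPnet.
import math
import torch
from tools.espnet_minimal.nets.scorer_interface import ScorerInterface  # assumed path

VOCAB = 52

class UniformScorer(ScorerInterface):
    """Toy full scorer: every token gets the same log-probability."""
    def score(self, y, state, x):
        return torch.full((VOCAB,), -math.log(VOCAB)), None

bs = BeamSearch(
    scorers={'uniform': UniformScorer()},  # non-Module scorers are allowed; only nn.Modules enter nn_dict
    weights={'uniform': 1.0},
    beam_size=5, vocab_size=VOCAB,
    sos=0, eos=1, token_list=None,
)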
append_token(xs, x)
staticmethod
¶
Append new token to prefix tokens.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
xs | torch.Tensor | The prefix token | required |
x | int | The new token to append | required |
Returns:
Type | Description |
---|---|
torch.Tensor | New tensor contains: xs + [x] with xs.dtype and xs.device |
Source code in adviser/tools/espnet_minimal/nets/beam_search.py
@staticmethod
def append_token(xs: torch.Tensor, x: int) -> torch.Tensor:
"""Append new token to prefix tokens.
Args:
xs (torch.Tensor): The prefix token
x (int): The new token to append
Returns:
torch.Tensor: New tensor contains: xs + [x] with xs.dtype and xs.device
"""
x = torch.tensor([x], dtype=xs.dtype, device=xs.device)
return torch.cat((xs, x))
beam(self, weighted_scores, ids)
¶
Compute topk full token ids and partial token ids.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
weighted_scores | torch.Tensor | The weighted sum scores for each token. Its shape is (self.n_vocab,). | required |
ids | torch.Tensor | The partial token ids to compute topk | required |
Returns:
Type | Description |
---|---|
Tuple[torch.Tensor, torch.Tensor] | The topk full token ids and partial token ids. Their shapes are (self.beam_size,) |
Source code in adviser/tools/espnet_minimal/nets/beam_search.py
def beam(self, weighted_scores: torch.Tensor, ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute topk full token ids and partial token ids.
Args:
weighted_scores (torch.Tensor): The weighted sum scores for each tokens. Its shape is `(self.n_vocab,)`.
ids (torch.Tensor): The partial token ids to compute topk
Returns:
Tuple[torch.Tensor, torch.Tensor]: The topk full token ids and partial token ids.
Their shapes are `(self.beam_size,)`
"""
# no pre beam performed
if weighted_scores.size(0) == ids.size(0):
top_ids = weighted_scores.topk(self.beam_size)[1]
return top_ids, top_ids
# mask pruned in pre-beam not to select in topk
tmp = weighted_scores[ids]
weighted_scores[:] = -float("inf")
weighted_scores[ids] = tmp
top_ids = weighted_scores.topk(self.beam_size)[1]
local_ids = weighted_scores[ids].topk(self.beam_size)[1]
return top_ids, local_ids
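The pre-beam masking step can be reproduced with plain tensors; this standalone snippet (not from the toolkit) keeps only the pre-selected ids before the final topk.
import torch

beam_size = 2
weighted_scores = torch.tensor([0.1, 0.9, 0.2, 0.8, 0.3])
ids = torch.tensor([0, 2, 3])            # candidates that survived the pre-beam
tmp = weighted_scores[ids]
weighted_scores[:] = -float('inf')       # everything outside `ids` can no longer win
weighted_scores[ids] = tmp
top_ids = weighted_scores.topk(beam_size)[1]          # indices in the full vocabulary
local_ids = weighted_scores[ids].topk(beam_size)[1]   # indices relative to `ids`
print(top_ids.tolist(), local_ids.tolist())           # [3, 2] and [2, 1]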
forward(self, x, maxlenratio=0.0, minlenratio=0.0)
¶
Perform beam search.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | torch.Tensor | Encoded speech feature (T, D) | required |
maxlenratio | float | Input length ratio to obtain max output length. If maxlenratio=0.0 (default), it uses an end-detect function to automatically find maximum hypothesis lengths | 0.0 |
minlenratio | float | Input length ratio to obtain min output length. | 0.0 |
Returns:
Type | Description |
---|---|
list[Hypothesis] | N-best decoding results |
Source code in adviser/tools/espnet_minimal/nets/beam_search.py
def forward(self, x: torch.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0) -> List[Hypothesis]:
"""Perform beam search.
Args:
x (torch.Tensor): Encoded speech feature (T, D)
maxlenratio (float): Input length ratio to obtain max output length.
If maxlenratio=0.0 (default), it uses an end-detect function
to automatically find maximum hypothesis lengths
minlenratio (float): Input length ratio to obtain min output length.
Returns:
list[Hypothesis]: N-best decoding results
"""
# set length bounds
if maxlenratio == 0:
maxlen = x.shape[0]
else:
maxlen = max(1, int(maxlenratio * x.size(0)))
minlen = int(minlenratio * x.size(0))
logging.info('max output length: ' + str(maxlen))
logging.info('min output length: ' + str(minlen))
# main loop of prefix search
running_hyps = self.init_hyp(x)
ended_hyps = []
for i in range(maxlen):
logging.debug('position ' + str(i))
best = self.search(running_hyps, x)
# post process of one iteration
running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
# end detection
if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i):
logging.info(f'end detected at {i}')
break
if len(running_hyps) == 0:
logging.info('no hypothesis. Finish decoding.')
break
else:
logging.debug(f'remaining hypotheses: {len(running_hyps)}')
nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
# check the number of hypotheses
if len(nbest_hyps) == 0:
logging.warning('there are no N-best results, performing recognition again with a smaller minlenratio.')
return [] if minlenratio < 0.1 else self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1))
# report the best result
best = nbest_hyps[0]
logging.info(f'total log probability: {best.score}')
logging.info(f'normalized log probability: {best.score / len(best.yseq)}')
return nbest_hyps
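Putting it together, and continuing the toy setup sketched after __init__ above (so everything here is illustrative): the module is called on a single encoded utterance and returns hypotheses sorted by score.
import torch

enc = torch.randn(40, 16)   # placeholder encoder output of shape (T, D)
nbest = bs.forward(enc, maxlenratio=0.5, minlenratio=0.0)
if nbest:
    best = nbest[0]
    print(best.score, best.yseq.tolist())  # total log probability and token ids (sos ... eos)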
init_hyp(self, x)
¶
Get an initial hypothesis data.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | torch.Tensor | The encoder output feature | required |
Returns:
Type | Description |
---|---|
Hypothesis | The initial hypothesis. |
Source code in adviser/tools/espnet_minimal/nets/beam_search.py
def init_hyp(self, x: torch.Tensor) -> Hypothesis:
"""Get an initial hypothesis data.
Args:
x (torch.Tensor): The encoder output feature
Returns:
Hypothesis: The initial hypothesis.
"""
init_states = dict()
init_scores = dict()
for k, d in self.scorers.items():
init_states[k] = d.init_state(x)
init_scores[k] = 0.0
return [Hypothesis(
score=0.0, scores=init_scores, states=init_states,
yseq=torch.tensor([self.sos], device=x.device))]
merge_scores(prev_scores, next_full_scores, full_idx, next_part_scores, part_idx)
staticmethod
¶
Merge scores for new hypothesis.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
prev_scores | Dict[str, float] | The previous hypothesis scores by self.scorers | required |
next_full_scores | Dict[str, torch.Tensor] | scores by self.full_scorers | required |
full_idx | int | The next token id for next_full_scores | required |
next_part_scores | Dict[str, torch.Tensor] | scores of partial tokens by self.part_scorers | required |
part_idx | int | The new token id for next_part_scores | required |
Returns:
Type | Description |
---|---|
Dict[str, torch.Tensor] | The new score dict. Its keys are names of self.full_scorers and self.part_scorers. Its values are scalar tensors by the scorers. |
Source code in adviser/tools/espnet_minimal/nets/beam_search.py
@staticmethod
def merge_scores(prev_scores: Dict[str, float], next_full_scores: Dict[str, torch.Tensor], full_idx: int,
next_part_scores: Dict[str, torch.Tensor], part_idx: int) -> Dict[str, torch.Tensor]:
"""Merge scores for new hypothesis.
Args:
prev_scores (Dict[str, float]): The previous hypothesis scores by `self.scorers`
next_full_scores (Dict[str, torch.Tensor]): scores by `self.full_scorers`
full_idx (int): The next token id for `next_full_scores`
next_part_scores (Dict[str, torch.Tensor]): scores of partial tokens by `self.part_scorers`
part_idx (int): The new token id for `next_part_scores`
Returns:
Dict[str, torch.Tensor]: The new score dict.
Its keys are names of `self.full_scorers` and `self.part_scorers`.
Its values are scalar tensors by the scorers.
"""
new_scores = dict()
for k, v in next_full_scores.items():
new_scores[k] = prev_scores[k] + v[full_idx]
for k, v in next_part_scores.items():
new_scores[k] = v[part_idx]
return new_scores
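Since merge_scores is a staticmethod, its accumulate-versus-replace behaviour can be shown directly: full scorers add their token score to the running total, while partial scorers keep only the newest partial score. Toy numbers; a full scorer 'decoder' and a partial scorer 'ctc' are invented for illustration.
import torch

prev_scores = {'decoder': -1.0, 'ctc': -0.5}
next_full_scores = {'decoder': torch.tensor([-0.2, -0.7, -1.5])}  # per-token scores over a toy vocab of 3
next_part_scores = {'ctc': torch.tensor([-0.3, -0.9])}            # scores only for the pre-selected ids
merged = BeamSearch.merge_scores(prev_scores, next_full_scores, 1, next_part_scores, 0)
print(merged['decoder'])  # tensor(-1.7000): previous -1.0 plus the token score -0.7
print(merged['ctc'])      # tensor(-0.3000): replaced, not accumulated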
merge_states(self, states, part_states, part_idx)
¶
Merge states for new hypothesis.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
states | Any | states of self.full_scorers | required |
part_states | Any | states of self.part_scorers | required |
part_idx | int | The new token id for part_scores | required |
Returns:
Type | Description |
---|---|
Dict[str, torch.Tensor] | The new state dict. Its keys are names of self.full_scorers and self.part_scorers. Its values are states of the scorers. |
Source code in adviser/tools/espnet_minimal/nets/beam_search.py
def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
"""Merge states for new hypothesis.
Args:
states: states of `self.full_scorers`
part_states: states of `self.part_scorers`
part_idx (int): The new token id for `part_scores`
Returns:
Dict[str, torch.Tensor]: The new score dict.
Its keys are names of `self.full_scorers` and `self.part_scorers`.
Its values are states of the scorers.
"""
new_states = dict()
for k, v in states.items():
new_states[k] = v
for k, d in self.part_scorers.items():
new_states[k] = d.select_state(part_states[k], part_idx)
return new_states
post_process(self, i, maxlen, maxlenratio, running_hyps, ended_hyps)
¶
Perform post-processing of beam search iterations.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
i | int | The length of hypothesis tokens. | required |
maxlen | int | The maximum length of tokens in beam search. | required |
maxlenratio | int | The maximum length ratio in beam search. | required |
running_hyps | List[Hypothesis] | The running hypotheses in beam search. | required |
ended_hyps | List[Hypothesis] | The ended hypotheses in beam search. | required |
Returns:
Type | Description |
---|---|
List[Hypothesis] | The new running hypotheses. |
Source code in adviser/tools/espnet_minimal/nets/beam_search.py
def post_process(self, i: int, maxlen: int, maxlenratio: float,
running_hyps: List[Hypothesis], ended_hyps: List[Hypothesis]) -> List[Hypothesis]:
"""Perform post-processing of beam search iterations.
Args:
i (int): The length of hypothesis tokens.
maxlen (int): The maximum length of tokens in beam search.
maxlenratio (int): The maximum length ratio in beam search.
running_hyps (List[Hypothesis]): The running hypotheses in beam search.
ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.
Returns:
List[Hypothesis]: The new running hypotheses.
"""
logging.debug(f'the number of running hypotheses: {len(running_hyps)}')
if self.token_list is not None:
logging.debug("best hypo: " + "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]]))
# add eos in the final loop to avoid that there are no ended hyps
if i == maxlen - 1:
logging.info("adding <eos> in the last position in the loop")
running_hyps = [h._replace(yseq=self.append_token(h.yseq, self.eos)) for h in running_hyps]
# add ended hypotheses to a final list, and remove them from the current hypotheses
# (this will be a problem if the number of hyps < beam)
remained_hyps = []
for hyp in running_hyps:
if hyp.yseq[-1] == self.eos:
# e.g., Word LM needs to add final <eos> score
for k, d in chain(self.full_scorers.items(), self.part_scorers.items()):
s = d.final_score(hyp.states[k])
hyp.scores[k] += s
hyp = hyp._replace(score=hyp.score + self.weights[k] * s)
ended_hyps.append(hyp)
else:
remained_hyps.append(hyp)
return remained_hyps
score_full(self, hyp, x)
¶
Score new hypothesis by self.full_scorers.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
hyp | Hypothesis | Hypothesis with prefix tokens to score | required |
x | torch.Tensor | Corresponding input feature | required |
Returns:
Type | Description |
---|---|
Tuple[Dict[str, torch.Tensor], Dict[str, Any]] | Tuple of score dict of hyp that has string keys of self.full_scorers and tensor score values of shape (self.n_vocab,), and state dict that has string keys and state values of self.full_scorers |
Source code in adviser/tools/espnet_minimal/nets/beam_search.py
def score_full(self, hyp: Hypothesis, x: torch.Tensor) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
"""Score new hypothesis by `self.full_scorers`.
Args:
hyp (Hypothesis): Hypothesis with prefix tokens to score
x (torch.Tensor): Corresponding input feature
Returns:
Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
score dict of `hyp` that has string keys of `self.full_scorers`
and tensor score values of shape: `(self.n_vocab,)`,
and state dict that has string keys and state values of `self.full_scorers`
"""
scores = dict()
states = dict()
for k, d in self.full_scorers.items():
scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], x)
return scores, states
score_partial(self, hyp, ids, x)
¶
Score new hypothesis by self.part_scorers.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
hyp | Hypothesis | Hypothesis with prefix tokens to score | required |
ids | torch.Tensor | 1D tensor of new partial tokens to score | required |
x | torch.Tensor | Corresponding input feature | required |
Returns:
Type | Description |
---|---|
Tuple[Dict[str, torch.Tensor], Dict[str, Any]] | Tuple of the score dict of hyp (string keys of self.part_scorers, tensor values of shape (len(ids),)) and the state dict (string keys and state values of self.part_scorers) |
Source code in adviser/tools/espnet_minimal/nets/beam_search.py
def score_partial(self, hyp: Hypothesis, ids: torch.Tensor, x: torch.Tensor) \
-> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
"""Score new hypothesis by `self.part_scorers`.
Args:
hyp (Hypothesis): Hypothesis with prefix tokens to score
ids (torch.Tensor): 1D tensor of new partial tokens to score
x (torch.Tensor): Corresponding input feature
Returns:
Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
score dict of `hyp` that has string keys of `self.part_scorers`
and tensor score values of shape: `(len(ids),)`,
and state dict that has string keys and state values of `self.part_scorers`
"""
scores = dict()
states = dict()
for k, d in self.part_scorers.items():
scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], x)
return scores, states
search(self, running_hyps, x)
¶
Search new tokens for running hypotheses and encoded speech x.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
running_hyps | List[Hypothesis] | Running hypotheses on beam | required |
x | torch.Tensor | Encoded speech feature (T, D) | required |
Returns:
Type | Description |
---|---|
List[Hypothesis] | Best sorted hypotheses |
Source code in adviser/tools/espnet_minimal/nets/beam_search.py
def search(self, running_hyps: List[Hypothesis], x: torch.Tensor) -> List[Hypothesis]:
"""Search new tokens for running hypotheses and encoded speech x.
Args:
running_hyps (List[Hypothesis]): Running hypotheses on beam
x (torch.Tensor): Encoded speech feature (T, D)
Returns:
List[Hypotheses]: Best sorted hypotheses
"""
best_hyps = []
part_ids = torch.arange(self.n_vocab, device=x.device) # no pre-beam
for hyp in running_hyps:
# scoring
scores, states = self.score_full(hyp, x)
if self.do_pre_beam:
part_ids = torch.topk(scores[self.pre_beam_score_key], self.pre_beam_size)[1]
part_scores, part_states = self.score_partial(hyp, part_ids, x)
# weighted sum scores
weighted_scores = torch.zeros(
self.n_vocab, dtype=x.dtype, device=x.device)
for k in self.full_scorers:
weighted_scores += self.weights[k] * scores[k]
for k in self.part_scorers:
weighted_scores[part_ids] += self.weights[k] * part_scores[k]
weighted_scores += hyp.score
# update hyps
for j, part_j in zip(*self.beam(weighted_scores, part_ids)):
# will be (2 x beam at most)
best_hyps.append(Hypothesis(
score=weighted_scores[j],
yseq=self.append_token(hyp.yseq, j),
scores=self.merge_scores(
hyp.scores, scores, j, part_scores, part_j),
states=self.merge_states(states, part_states, part_j)))
# sort and prune 2 x beam -> beam
best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[:min(len(best_hyps), self.beam_size)]
return best_hyps
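As a toy illustration of the scoring step above (all values below are made up and the scorer names 'decoder' and 'ctc' are only examples), full scorers contribute a score for every vocabulary entry while partial scorers are added only on the pre-beam candidates part_ids:
import torch
# Toy sketch of the weighted-sum scoring inside search(); not part of the module.
n_vocab = 8
weights = {'decoder': 0.7, 'ctc': 0.3}
full_scores = {'decoder': torch.log_softmax(torch.randn(n_vocab), dim=-1)}  # scored over the whole vocabulary
part_ids = torch.topk(full_scores['decoder'], 4)[1]                         # pre-beam: keep the 4 best decoder tokens
part_scores = {'ctc': torch.randn(len(part_ids))}                           # scored only for part_ids
weighted_scores = torch.zeros(n_vocab)
weighted_scores += weights['decoder'] * full_scores['decoder']
weighted_scores[part_ids] += weights['ctc'] * part_scores['ctc']
weighted_scores += -1.5                                                     # accumulated score of the parent hypothesis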
Hypothesis (tuple)
¶
Hypothesis data type.
Source code in adviser/tools/espnet_minimal/nets/beam_search.py
class Hypothesis(NamedTuple):
"""Hypothesis data type."""
yseq: torch.Tensor
score: float = 0
scores: Dict[str, float] = dict()
states: Dict[str, Dict] = dict()
def asdict(self) -> dict:
"""Convert data to JSON-friendly dict."""
return self._replace(
yseq=self.yseq.tolist(),
score=float(self.score),
scores={k: float(v) for k, v in self.scores.items()}
)._asdict()
score: float
¶
scores: Dict[str, float]
¶
states: Dict[str, Dict]
¶
yseq: Tensor
¶
__getnewargs__(self)
special
¶
__new__(_cls, yseq, score=0, scores={}, states={})
special
staticmethod
¶
Create new instance of Hypothesis(yseq, score, scores, states)
__repr__(self)
special
¶
asdict(self)
¶
Convert data to JSON-friendly dict.
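A minimal sketch of building a Hypothesis by hand and serializing it (the import path is an assumption about this minimal copy):
import torch
from tools.espnet_minimal.nets.beam_search import Hypothesis  # assumed import path
hyp = Hypothesis(yseq=torch.tensor([1, 5, 9]),
                 score=-3.2,
                 scores={'decoder': -2.0, 'ctc': -1.2},
                 states={})
print(hyp.asdict())  # yseq becomes a plain list, score/scores become plain floats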
beam_search(x, sos, eos, beam_size, vocab_size, scorers, weights, token_list=None, maxlenratio=0.0, minlenratio=0.0, pre_beam_ratio=1.5, pre_beam_score_key='decoder')
¶
Perform beam search with scorers.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | torch.Tensor | Encoded speech feature (T, D) | required |
sos | int | Start of sequence id | required |
eos | int | End of sequence id | required |
beam_size | int | The number of hypotheses kept during search | required |
vocab_size | int | The size of the vocabulary | required |
scorers | dict[str, ScorerInterface] | Dict of decoder modules, e.g. Decoder, CTCPrefixScorer, LM. The scorer will be ignored if it is None | required |
weights | dict[str, float] | Dict of weights for each scorer. The scorer will be ignored if its weight is 0 | required |
token_list | list[str] | List of tokens for debug log | None |
maxlenratio | float | Input length ratio to obtain max output length. If maxlenratio=0.0 (default), it uses an end-detect function to automatically find maximum hypothesis lengths | 0.0 |
minlenratio | float | Input length ratio to obtain min output length | 0.0 |
pre_beam_score_key | str | Key of scores used to perform pre-beam search | 'decoder' |
pre_beam_ratio | float | The beam size in the pre-beam search will be int(pre_beam_ratio * beam_size) | 1.5 |
Returns:
Type | Description |
---|---|
list | N-best decoding results |
Source code in adviser/tools/espnet_minimal/nets/beam_search.py
def beam_search(x: torch.Tensor, sos: int, eos: int, beam_size: int, vocab_size: int,
scorers: Dict[str, ScorerInterface], weights: Dict[str, float],
token_list: List[str] = None, maxlenratio: float = 0.0, minlenratio: float = 0.0,
pre_beam_ratio: float = 1.5, pre_beam_score_key: str = "decoder") -> list:
"""Perform beam search with scorers.
Args:
x (torch.Tensor): Encoded speech feature (T, D)
sos (int): Start of sequence id
eos (int): End of sequence id
beam_size (int): The number of hypotheses kept during search
vocab_size (int): The number of vocabulary
scorers (dict[str, ScorerInterface]): Dict of decoder modules e.g., Decoder, CTCPrefixScorer, LM
The scorer will be ignored if it is `None`
weights (dict[str, float]): Dict of weights for each scorers
The scorer will be ignored if its weight is 0
token_list (list[str]): List of tokens for debug log
maxlenratio (float): Input length ratio to obtain max output length.
If maxlenratio=0.0 (default), it uses a end-detect function
to automatically find maximum hypothesis lengths
minlenratio (float): Input length ratio to obtain min output length.
pre_beam_score_key (str): key of scores to perform pre-beam search
pre_beam_ratio (float): beam size in the pre-beam search will be `int(pre_beam_ratio * beam_size)`
Returns:
list: N-best decoding results
"""
ret = BeamSearch(
scorers, weights,
beam_size=beam_size,
vocab_size=vocab_size,
pre_beam_ratio=pre_beam_ratio,
pre_beam_score_key=pre_beam_score_key,
sos=sos,
eos=eos,
token_list=token_list,
).forward(
x=x,
maxlenratio=maxlenratio,
minlenratio=minlenratio)
return [h.asdict() for h in ret]
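A hedged call sketch (enc_out, sos_id, eos_id, odim, char_list, decoder and ctc_scorer are placeholders that must come from an actual model and its ScorerInterface implementations; they are not defined here):
nbest = beam_search(x=enc_out,                  # encoded speech feature (T, D)
                    sos=sos_id, eos=eos_id,
                    beam_size=10, vocab_size=odim,
                    scorers={'decoder': decoder, 'ctc': ctc_scorer},
                    weights={'decoder': 0.7, 'ctc': 0.3},
                    token_list=char_list)
best = nbest[0]                                 # dict with 'yseq', 'score', 'scores', 'states'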
ctc_prefix_score
¶
CTCPrefixScore
¶
Compute CTC label sequence scores
which is based on Algorithm 2 in WATANABE et al. "HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION," but extended to efficiently compute the probabilities of multiple labels simultaneously
Source code in adviser/tools/espnet_minimal/nets/ctc_prefix_score.py
class CTCPrefixScore(object):
"""Compute CTC label sequence scores
which is based on Algorithm 2 in WATANABE et al.
"HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION,"
but extended to efficiently compute the probablities of multiple labels
simultaneously
"""
def __init__(self, x, blank, eos, xp):
self.xp = xp
self.logzero = -10000000000.0
self.blank = blank
self.eos = eos
self.input_length = len(x)
self.x = x
def initial_state(self):
"""Obtain an initial CTC state
:return: CTC state
"""
# initial CTC state is made of a frame x 2 tensor that corresponds to
# r_t^n(<sos>) and r_t^b(<sos>), where 0 and 1 of axis=1 represent
# superscripts n and b (non-blank and blank), respectively.
r = self.xp.full((self.input_length, 2), self.logzero, dtype=np.float32)
r[0, 1] = self.x[0, self.blank]
for i in six.moves.range(1, self.input_length):
r[i, 1] = r[i - 1, 1] + self.x[i, self.blank]
return r
def __call__(self, y, cs, r_prev):
"""Compute CTC prefix scores for next labels
:param y : prefix label sequence
:param cs : array of next labels
:param r_prev: previous CTC state
:return ctc_scores, ctc_states
"""
# initialize CTC states
output_length = len(y) - 1 # ignore sos
# new CTC states are prepared as a frame x (n or b) x n_labels tensor
# that corresponds to r_t^n(h) and r_t^b(h).
r = self.xp.ndarray((self.input_length, 2, len(cs)), dtype=np.float32)
xs = self.x[:, cs]
if output_length == 0:
r[0, 0] = xs[0]
r[0, 1] = self.logzero
else:
r[output_length - 1] = self.logzero
# prepare forward probabilities for the last label
r_sum = self.xp.logaddexp(r_prev[:, 0], r_prev[:, 1]) # log(r_t^n(g) + r_t^b(g))
last = y[-1]
if output_length > 0 and last in cs:
log_phi = self.xp.ndarray((self.input_length, len(cs)), dtype=np.float32)
for i in six.moves.range(len(cs)):
log_phi[:, i] = r_sum if cs[i] != last else r_prev[:, 1]
else:
log_phi = r_sum
# compute forward probabilities log(r_t^n(h)), log(r_t^b(h)),
# and log prefix probabilites log(psi)
start = max(output_length, 1)
log_psi = r[start - 1, 0]
for t in six.moves.range(start, self.input_length):
r[t, 0] = self.xp.logaddexp(r[t - 1, 0], log_phi[t - 1]) + xs[t]
r[t, 1] = self.xp.logaddexp(r[t - 1, 0], r[t - 1, 1]) + self.x[t, self.blank]
log_psi = self.xp.logaddexp(log_psi, log_phi[t - 1] + xs[t])
# get P(...eos|X) that ends with the prefix itself
eos_pos = self.xp.where(cs == self.eos)[0]
if len(eos_pos) > 0:
log_psi[eos_pos] = r_sum[-1] # log(r_T^n(g) + r_T^b(g))
# return the log prefix probability and CTC states, where the label axis
# of the CTC states is moved to the first axis to slice it easily
return log_psi, self.xp.rollaxis(r, 2)
__call__(self, y, cs, r_prev)
special
¶
Compute CTC prefix scores for next labels
:param y : prefix label sequence :param cs : array of next labels :param r_prev: previous CTC state :return ctc_scores, ctc_states
Source code in adviser/tools/espnet_minimal/nets/ctc_prefix_score.py
def __call__(self, y, cs, r_prev):
"""Compute CTC prefix scores for next labels
:param y : prefix label sequence
:param cs : array of next labels
:param r_prev: previous CTC state
:return ctc_scores, ctc_states
"""
# initialize CTC states
output_length = len(y) - 1 # ignore sos
# new CTC states are prepared as a frame x (n or b) x n_labels tensor
# that corresponds to r_t^n(h) and r_t^b(h).
r = self.xp.ndarray((self.input_length, 2, len(cs)), dtype=np.float32)
xs = self.x[:, cs]
if output_length == 0:
r[0, 0] = xs[0]
r[0, 1] = self.logzero
else:
r[output_length - 1] = self.logzero
# prepare forward probabilities for the last label
r_sum = self.xp.logaddexp(r_prev[:, 0], r_prev[:, 1]) # log(r_t^n(g) + r_t^b(g))
last = y[-1]
if output_length > 0 and last in cs:
log_phi = self.xp.ndarray((self.input_length, len(cs)), dtype=np.float32)
for i in six.moves.range(len(cs)):
log_phi[:, i] = r_sum if cs[i] != last else r_prev[:, 1]
else:
log_phi = r_sum
# compute forward probabilities log(r_t^n(h)), log(r_t^b(h)),
# and log prefix probabilites log(psi)
start = max(output_length, 1)
log_psi = r[start - 1, 0]
for t in six.moves.range(start, self.input_length):
r[t, 0] = self.xp.logaddexp(r[t - 1, 0], log_phi[t - 1]) + xs[t]
r[t, 1] = self.xp.logaddexp(r[t - 1, 0], r[t - 1, 1]) + self.x[t, self.blank]
log_psi = self.xp.logaddexp(log_psi, log_phi[t - 1] + xs[t])
# get P(...eos|X) that ends with the prefix itself
eos_pos = self.xp.where(cs == self.eos)[0]
if len(eos_pos) > 0:
log_psi[eos_pos] = r_sum[-1] # log(r_T^n(g) + r_T^b(g))
# return the log prefix probability and CTC states, where the label axis
# of the CTC states is moved to the first axis to slice it easily
return log_psi, self.xp.rollaxis(r, 2)
__init__(self, x, blank, eos, xp)
special
¶
initial_state(self)
¶
Obtain an initial CTC state
:return: CTC state
Source code in adviser/tools/espnet_minimal/nets/ctc_prefix_score.py
def initial_state(self):
"""Obtain an initial CTC state
:return: CTC state
"""
# initial CTC state is made of a frame x 2 tensor that corresponds to
# r_t^n(<sos>) and r_t^b(<sos>), where 0 and 1 of axis=1 represent
# superscripts n and b (non-blank and blank), respectively.
r = self.xp.full((self.input_length, 2), self.logzero, dtype=np.float32)
r[0, 1] = self.x[0, self.blank]
for i in six.moves.range(1, self.input_length):
r[i, 1] = r[i - 1, 1] + self.x[i, self.blank]
return r
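A self-contained NumPy sketch of one scoring step (the import path and the toy posteriors are assumptions, not taken from the package):
import numpy as np
from tools.espnet_minimal.nets.ctc_prefix_score import CTCPrefixScore  # assumed import path
T, odim, blank, eos, sos = 5, 6, 0, 5, 5                          # toy sizes; sos and eos share the last id
logits = np.random.randn(T, odim).astype(np.float32)
x = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))    # log-softmax posteriors (T, odim)
scorer = CTCPrefixScore(x, blank, eos, np)                        # xp=np selects the NumPy backend
r_prev = scorer.initial_state()                                   # (T, 2) forward variables for the empty prefix
cs = np.array([1, 2, 3, eos])                                     # candidate next labels
log_psi, ctc_states = scorer([sos], cs, r_prev)                   # prefix scores for <sos> followed by each candidate
best_label = cs[int(np.argmax(log_psi))]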
CTCPrefixScoreTH
¶
Batch processing of CTCPrefixScore
which is based on Algorithm 2 in WATANABE et al. "HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION," but extended to efficiently compute the probabilities of multiple labels simultaneously
Source code in adviser/tools/espnet_minimal/nets/ctc_prefix_score.py
class CTCPrefixScoreTH(object):
"""Batch processing of CTCPrefixScore
which is based on Algorithm 2 in WATANABE et al.
"HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION,"
but extended to efficiently compute the probablities of multiple labels
simultaneously
"""
def __init__(self, x, xlens, blank, eos, beam, scoring_ratio=1.5, margin=0):
"""Construct CTC prefix scorer
:param torch.Tensor x: input label posterior sequences (B, T, O)
:param torch.Tensor xlens: input lengths (B,)
:param int blank: blank label id
:param int eos: end-of-sequence id
:param int beam: beam size
:param float scoring_ratio: ratio of #scored hypos to beam size
:param int margin: margin parameter for windowing (0 means no windowing)
"""
# In the comment lines, we assume T: input_length, B: batch size, W: beam width, O: output dim.
self.logzero = -10000000000.0
self.blank = blank
self.eos = eos
self.batch = x.size(0)
self.input_length = x.size(1)
self.odim = x.size(2)
self.beam = beam
self.n_bb = self.batch * beam
self.device = torch.device('cuda:%d' % x.get_device()) if x.is_cuda else torch.device('cpu')
# Pad the rest of posteriors in the batch
# TODO(takaaki-hori): need a better way without for-loops
for i, l in enumerate(xlens):
if l < self.input_length:
x[i, l:, :] = self.logzero
x[i, l:, blank] = 0
# Set the number of scoring hypotheses (scoring_num=0 means all)
self.scoring_num = int(beam * scoring_ratio)
if self.scoring_num >= self.odim:
self.scoring_num = 0
# Expand input posteriors for fast computation
if self.scoring_num == 0:
xn = x.transpose(0, 1).unsqueeze(2).repeat(1, 1, beam, 1).view(-1, self.n_bb, self.odim)
else:
xn = x.transpose(0, 1)
xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim)
self.x = torch.stack([xn, xb]) # (2, T, B, O) or (2, T, BW, O)
# Setup CTC windowing
self.margin = margin
if margin > 0:
self.frame_ids = torch.arange(self.input_length, dtype=torch.float32, device=self.device)
# Precompute end frames (BW,)
self.end_frames = (torch.as_tensor(xlens) - 1).view(self.batch, 1).repeat(1, beam).view(-1)
# Precompute base indices to convert label ids to corresponding element indices
self.pad_b = (torch.arange(self.batch, device=self.device) * beam).view(-1, 1)
self.pad_bo = (torch.arange(self.batch, device=self.device) * (beam * self.odim)).view(-1, 1)
self.pad_o = (torch.arange(self.batch, device=self.device) * self.odim).unsqueeze(1).repeat(1, beam).view(-1, 1)
self.bb_idx = torch.arange(self.n_bb, device=self.device).view(-1, 1)
def __call__(self, y, state, pre_scores=None, att_w=None):
"""Compute CTC prefix scores for next labels
:param list y: prefix label sequences
:param tuple state: previous CTC state
:param torch.Tensor pre_scores: scores for pre-selection of hypotheses (BW, O)
:param torch.Tensor att_w: attention weights to decide CTC window
:return new_state, ctc_local_scores (BW, O)
"""
output_length = len(y[0]) - 1 # ignore sos
last_ids = [yi[-1] for yi in y] # last output label ids
# prepare state info
if state is None:
if self.scoring_num > 0:
r_prev = torch.full((self.input_length, 2, self.batch, self.beam),
self.logzero, dtype=torch.float32, device=self.device)
r_prev[:, 1] = torch.cumsum(self.x[0, :, :, self.blank], 0).unsqueeze(2)
r_prev = r_prev.view(-1, 2, self.n_bb)
else:
r_prev = torch.full((self.input_length, 2, self.n_bb),
self.logzero, dtype=torch.float32, device=self.device)
r_prev[:, 1] = torch.cumsum(self.x[0, :, :, self.blank], 0)
s_prev = 0.0
f_min_prev = 0
f_max_prev = 1
else:
r_prev, s_prev, f_min_prev, f_max_prev = state
# select input dimensions for scoring
if self.scoring_num > 0 and pre_scores is not None:
pre_scores[:, self.blank] = self.logzero # ignore blank from pre-selection
scoring_ids = torch.topk(pre_scores, self.scoring_num, 1)[1]
scoring_idmap = torch.full((self.n_bb, self.odim), -1, dtype=torch.long, device=self.device)
snum = scoring_ids.size(1)
scoring_idmap[self.bb_idx, scoring_ids] = torch.arange(snum, device=self.device)
scoring_idx = (scoring_ids + self.pad_o).view(-1)
x_ = torch.index_select(self.x.view(2, -1, self.batch * self.odim),
2, scoring_idx).view(2, -1, self.n_bb, snum)
else:
scoring_ids = None
scoring_idmap = None
snum = self.odim
x_ = self.x
# new CTC forward probs are prepared as a (T x 2 x BW x S) tensor
# that corresponds to r_t^n(h) and r_t^b(h) in a batch.
r = torch.full((self.input_length, 2, self.n_bb, snum),
self.logzero, dtype=torch.float32, device=self.device)
if output_length == 0:
r[0, 0] = x_[0, 0]
r_sum = torch.logsumexp(r_prev, 1)
log_phi = r_sum.unsqueeze(2).repeat(1, 1, snum)
if scoring_ids is not None:
for idx in range(self.n_bb):
pos = scoring_idmap[idx, last_ids[idx]]
if pos >= 0:
log_phi[:, idx, pos] = r_prev[:, 1, idx]
else:
for idx in range(self.n_bb):
log_phi[:, idx, last_ids[idx]] = r_prev[:, 1, idx]
# decide start and end frames based on attention weights
if att_w is not None and self.margin > 0:
f_arg = torch.matmul(att_w, self.frame_ids)
f_min = max(int(f_arg.min().cpu()), f_min_prev)
f_max = max(int(f_arg.max().cpu()), f_max_prev)
start = min(f_max_prev, max(f_min - self.margin, output_length, 1))
end = min(f_max + self.margin, self.input_length)
else:
f_min = f_max = 0
start = max(output_length, 1)
end = self.input_length
# compute forward probabilities log(r_t^n(h)) and log(r_t^b(h))
for t in range(start, end):
rp = r[t - 1]
rr = torch.stack([rp[0], log_phi[t - 1], rp[0], rp[1]]).view(2, 2, self.n_bb, snum)
r[t] = torch.logsumexp(rr, 1) + x_[:, t]
# compute log prefix probabilites log(psi)
log_phi_x = torch.cat((log_phi[0].unsqueeze(0), log_phi[:-1]), dim=0) + x_[0]
if scoring_ids is not None:
log_psi = torch.full((self.n_bb, self.odim), self.logzero, device=self.device)
log_psi_ = torch.logsumexp(torch.cat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), dim=0), dim=0)
for si in range(self.n_bb):
log_psi[si, scoring_ids[si]] = log_psi_[si]
else:
log_psi = torch.logsumexp(torch.cat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), dim=0), dim=0)
for si in range(self.n_bb):
log_psi[si, self.eos] = r_sum[self.end_frames[si], si]
return (r, log_psi, f_min, f_max, scoring_idmap), log_psi - s_prev
def index_select_state(self, state, best_ids):
"""Select CTC states according to best ids
:param state : CTC state
:param best_ids : index numbers selected by beam pruning (B, W)
:return selected_state
"""
r, s, f_min, f_max, scoring_idmap = state
# convert ids to BWO space
vidx = (best_ids + self.pad_bo).view(-1)
# select hypothesis scores
s_new = torch.index_select(s.view(-1), 0, vidx)
s_new = s_new.view(-1, 1).repeat(1, self.odim).view(self.n_bb, self.odim)
# convert ids to BWS space (S: scoring_num)
if scoring_idmap is not None:
snum = self.scoring_num
beam_idx = (torch.div(best_ids, self.odim) + self.pad_b).view(-1)
label_ids = torch.fmod(best_ids, self.odim).view(-1)
score_idx = scoring_idmap[beam_idx, label_ids]
score_idx[score_idx == -1] = 0
vidx = score_idx + beam_idx * snum
else:
snum = self.odim
# select forward probabilities
r_new = torch.index_select(r.view(-1, 2, self.n_bb * snum), 2, vidx).view(-1, 2, self.n_bb)
return r_new, s_new, f_min, f_max
__call__(self, y, state, pre_scores=None, att_w=None)
special
¶
Compute CTC prefix scores for next labels
:param list y: prefix label sequences :param tuple state: previous CTC state :param torch.Tensor pre_scores: scores for pre-selection of hypotheses (BW, O) :param torch.Tensor att_w: attention weights to decide CTC window :return new_state, ctc_local_scores (BW, O)
Source code in adviser/tools/espnet_minimal/nets/ctc_prefix_score.py
def __call__(self, y, state, pre_scores=None, att_w=None):
"""Compute CTC prefix scores for next labels
:param list y: prefix label sequences
:param tuple state: previous CTC state
:param torch.Tensor pre_scores: scores for pre-selection of hypotheses (BW, O)
:param torch.Tensor att_w: attention weights to decide CTC window
:return new_state, ctc_local_scores (BW, O)
"""
output_length = len(y[0]) - 1 # ignore sos
last_ids = [yi[-1] for yi in y] # last output label ids
# prepare state info
if state is None:
if self.scoring_num > 0:
r_prev = torch.full((self.input_length, 2, self.batch, self.beam),
self.logzero, dtype=torch.float32, device=self.device)
r_prev[:, 1] = torch.cumsum(self.x[0, :, :, self.blank], 0).unsqueeze(2)
r_prev = r_prev.view(-1, 2, self.n_bb)
else:
r_prev = torch.full((self.input_length, 2, self.n_bb),
self.logzero, dtype=torch.float32, device=self.device)
r_prev[:, 1] = torch.cumsum(self.x[0, :, :, self.blank], 0)
s_prev = 0.0
f_min_prev = 0
f_max_prev = 1
else:
r_prev, s_prev, f_min_prev, f_max_prev = state
# select input dimensions for scoring
if self.scoring_num > 0 and pre_scores is not None:
pre_scores[:, self.blank] = self.logzero # ignore blank from pre-selection
scoring_ids = torch.topk(pre_scores, self.scoring_num, 1)[1]
scoring_idmap = torch.full((self.n_bb, self.odim), -1, dtype=torch.long, device=self.device)
snum = scoring_ids.size(1)
scoring_idmap[self.bb_idx, scoring_ids] = torch.arange(snum, device=self.device)
scoring_idx = (scoring_ids + self.pad_o).view(-1)
x_ = torch.index_select(self.x.view(2, -1, self.batch * self.odim),
2, scoring_idx).view(2, -1, self.n_bb, snum)
else:
scoring_ids = None
scoring_idmap = None
snum = self.odim
x_ = self.x
# new CTC forward probs are prepared as a (T x 2 x BW x S) tensor
# that corresponds to r_t^n(h) and r_t^b(h) in a batch.
r = torch.full((self.input_length, 2, self.n_bb, snum),
self.logzero, dtype=torch.float32, device=self.device)
if output_length == 0:
r[0, 0] = x_[0, 0]
r_sum = torch.logsumexp(r_prev, 1)
log_phi = r_sum.unsqueeze(2).repeat(1, 1, snum)
if scoring_ids is not None:
for idx in range(self.n_bb):
pos = scoring_idmap[idx, last_ids[idx]]
if pos >= 0:
log_phi[:, idx, pos] = r_prev[:, 1, idx]
else:
for idx in range(self.n_bb):
log_phi[:, idx, last_ids[idx]] = r_prev[:, 1, idx]
# decide start and end frames based on attention weights
if att_w is not None and self.margin > 0:
f_arg = torch.matmul(att_w, self.frame_ids)
f_min = max(int(f_arg.min().cpu()), f_min_prev)
f_max = max(int(f_arg.max().cpu()), f_max_prev)
start = min(f_max_prev, max(f_min - self.margin, output_length, 1))
end = min(f_max + self.margin, self.input_length)
else:
f_min = f_max = 0
start = max(output_length, 1)
end = self.input_length
# compute forward probabilities log(r_t^n(h)) and log(r_t^b(h))
for t in range(start, end):
rp = r[t - 1]
rr = torch.stack([rp[0], log_phi[t - 1], rp[0], rp[1]]).view(2, 2, self.n_bb, snum)
r[t] = torch.logsumexp(rr, 1) + x_[:, t]
# compute log prefix probabilites log(psi)
log_phi_x = torch.cat((log_phi[0].unsqueeze(0), log_phi[:-1]), dim=0) + x_[0]
if scoring_ids is not None:
log_psi = torch.full((self.n_bb, self.odim), self.logzero, device=self.device)
log_psi_ = torch.logsumexp(torch.cat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), dim=0), dim=0)
for si in range(self.n_bb):
log_psi[si, scoring_ids[si]] = log_psi_[si]
else:
log_psi = torch.logsumexp(torch.cat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), dim=0), dim=0)
for si in range(self.n_bb):
log_psi[si, self.eos] = r_sum[self.end_frames[si], si]
return (r, log_psi, f_min, f_max, scoring_idmap), log_psi - s_prev
__init__(self, x, xlens, blank, eos, beam, scoring_ratio=1.5, margin=0)
special
¶
Construct CTC prefix scorer
:param torch.Tensor x: input label posterior sequences (B, T, O) :param torch.Tensor xlens: input lengths (B,) :param int blank: blank label id :param int eos: end-of-sequence id :param int beam: beam size :param float scoring_ratio: ratio of #scored hypos to beam size :param int margin: margin parameter for windowing (0 means no windowing)
Source code in adviser/tools/espnet_minimal/nets/ctc_prefix_score.py
def __init__(self, x, xlens, blank, eos, beam, scoring_ratio=1.5, margin=0):
"""Construct CTC prefix scorer
:param torch.Tensor x: input label posterior sequences (B, T, O)
:param torch.Tensor xlens: input lengths (B,)
:param int blank: blank label id
:param int eos: end-of-sequence id
:param int beam: beam size
:param float scoring_ratio: ratio of #scored hypos to beam size
:param int margin: margin parameter for windowing (0 means no windowing)
"""
# In the comment lines, we assume T: input_length, B: batch size, W: beam width, O: output dim.
self.logzero = -10000000000.0
self.blank = blank
self.eos = eos
self.batch = x.size(0)
self.input_length = x.size(1)
self.odim = x.size(2)
self.beam = beam
self.n_bb = self.batch * beam
self.device = torch.device('cuda:%d' % x.get_device()) if x.is_cuda else torch.device('cpu')
# Pad the rest of posteriors in the batch
# TODO(takaaki-hori): need a better way without for-loops
for i, l in enumerate(xlens):
if l < self.input_length:
x[i, l:, :] = self.logzero
x[i, l:, blank] = 0
# Set the number of scoring hypotheses (scoring_num=0 means all)
self.scoring_num = int(beam * scoring_ratio)
if self.scoring_num >= self.odim:
self.scoring_num = 0
# Expand input posteriors for fast computation
if self.scoring_num == 0:
xn = x.transpose(0, 1).unsqueeze(2).repeat(1, 1, beam, 1).view(-1, self.n_bb, self.odim)
else:
xn = x.transpose(0, 1)
xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim)
self.x = torch.stack([xn, xb]) # (2, T, B, O) or (2, T, BW, O)
# Setup CTC windowing
self.margin = margin
if margin > 0:
self.frame_ids = torch.arange(self.input_length, dtype=torch.float32, device=self.device)
# Precompute end frames (BW,)
self.end_frames = (torch.as_tensor(xlens) - 1).view(self.batch, 1).repeat(1, beam).view(-1)
# Precompute base indices to convert label ids to corresponding element indices
self.pad_b = (torch.arange(self.batch, device=self.device) * beam).view(-1, 1)
self.pad_bo = (torch.arange(self.batch, device=self.device) * (beam * self.odim)).view(-1, 1)
self.pad_o = (torch.arange(self.batch, device=self.device) * self.odim).unsqueeze(1).repeat(1, beam).view(-1, 1)
self.bb_idx = torch.arange(self.n_bb, device=self.device).view(-1, 1)
index_select_state(self, state, best_ids)
¶
Select CTC states according to best ids
:param state : CTC state :param best_ids : index numbers selected by beam pruning (B, W) :return selected_state
Source code in adviser/tools/espnet_minimal/nets/ctc_prefix_score.py
def index_select_state(self, state, best_ids):
"""Select CTC states according to best ids
:param state : CTC state
:param best_ids : index numbers selected by beam pruning (B, W)
:return selected_state
"""
r, s, f_min, f_max, scoring_idmap = state
# convert ids to BWO space
vidx = (best_ids + self.pad_bo).view(-1)
# select hypothesis scores
s_new = torch.index_select(s.view(-1), 0, vidx)
s_new = s_new.view(-1, 1).repeat(1, self.odim).view(self.n_bb, self.odim)
# convert ids to BWS space (S: scoring_num)
if scoring_idmap is not None:
snum = self.scoring_num
beam_idx = (torch.div(best_ids, self.odim) + self.pad_b).view(-1)
label_ids = torch.fmod(best_ids, self.odim).view(-1)
score_idx = scoring_idmap[beam_idx, label_ids]
score_idx[score_idx == -1] = 0
vidx = score_idx + beam_idx * snum
else:
snum = self.odim
# select forward probabilities
r_new = torch.index_select(r.view(-1, 2, self.n_bb * snum), 2, vidx).view(-1, 2, self.n_bb)
return r_new, s_new, f_min, f_max
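A batched sketch in torch (the import path, the toy sizes and the random pre-selection scores are assumptions):
import torch
from tools.espnet_minimal.nets.ctc_prefix_score import CTCPrefixScoreTH  # assumed import path
B, T, odim, beam, blank, eos = 2, 6, 10, 3, 0, 9                  # toy sizes
x = torch.randn(B, T, odim).log_softmax(dim=-1)                   # label posteriors (B, T, O)
xlens = torch.tensor([6, 5])                                      # valid input lengths (B,)
scorer = CTCPrefixScoreTH(x, xlens, blank, eos, beam)
y = [[eos] for _ in range(B * beam)]                              # one <sos>-only prefix per batch*beam slot
pre_scores = torch.randn(B * beam, odim)                          # e.g. decoder scores used for pre-selection
state, local_scores = scorer(y, None, pre_scores=pre_scores)      # local_scores has shape (B*beam, O)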
e2e_asr_common
¶
ErrorCalculator
¶
Calculate CER and WER for E2E_ASR and CTC models during training
:param y_hats: numpy array with predicted text :param y_pads: numpy array with true (target) text :param char_list: :param sym_space: :param sym_blank: :return:
Source code in adviser/tools/espnet_minimal/nets/e2e_asr_common.py
class ErrorCalculator(object):
"""Calculate CER and WER for E2E_ASR and CTC models during training
:param y_hats: numpy array with predicted text
:param y_pads: numpy array with true (target) text
:param char_list:
:param sym_space:
:param sym_blank:
:return:
"""
def __init__(self, char_list, sym_space, sym_blank,
report_cer=False, report_wer=False):
super(ErrorCalculator, self).__init__()
self.char_list = char_list
self.space = sym_space
self.blank = sym_blank
self.report_cer = report_cer
self.report_wer = report_wer
self.idx_blank = self.char_list.index(self.blank)
if self.space in self.char_list:
self.idx_space = self.char_list.index(self.space)
else:
self.idx_space = None
def __call__(self, ys_hat, ys_pad, is_ctc=False):
cer, wer = None, None
if is_ctc:
return self.calculate_cer_ctc(ys_hat, ys_pad)
elif not self.report_cer and not self.report_wer:
return cer, wer
seqs_hat, seqs_true = self.convert_to_char(ys_hat, ys_pad)
if self.report_cer:
cer = self.calculate_cer(seqs_hat, seqs_true)
if self.report_wer:
wer = self.calculate_wer(seqs_hat, seqs_true)
return cer, wer
def calculate_cer_ctc(self, ys_hat, ys_pad):
cers, char_ref_lens = [], []
for i, y in enumerate(ys_hat):
y_hat = [x[0] for x in groupby(y)]
y_true = ys_pad[i]
seq_hat, seq_true = [], []
for idx in y_hat:
idx = int(idx)
if idx != -1 and idx != self.idx_blank and idx != self.idx_space:
seq_hat.append(self.char_list[int(idx)])
for idx in y_true:
idx = int(idx)
if idx != -1 and idx != self.idx_blank and idx != self.idx_space:
seq_true.append(self.char_list[int(idx)])
hyp_chars = "".join(seq_hat)
ref_chars = "".join(seq_true)
cer_ctc = float(sum(cers)) / sum(char_ref_lens) if cers else None
return cer_ctc
def convert_to_char(self, ys_hat, ys_pad):
seqs_hat, seqs_true = [], []
for i, y_hat in enumerate(ys_hat):
y_true = ys_pad[i]
eos_true = np.where(y_true == -1)[0]
eos_true = eos_true[0] if len(eos_true) > 0 else len(y_true)
# To avoid wrong higger WER than the one obtained from the decoding
# eos from y_true is used to mark the eos in y_hat
# because of that y_hats has not padded outs with -1.
seq_hat = [self.char_list[int(idx)] for idx in y_hat[:eos_true]]
seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1]
seq_hat_text = "".join(seq_hat).replace(self.space, ' ')
seq_hat_text = seq_hat_text.replace(self.blank, '')
seq_true_text = "".join(seq_true).replace(self.space, ' ')
seqs_hat.append(seq_hat_text)
seqs_true.append(seq_true_text)
return seqs_hat, seqs_true
def calculate_cer(self, seqs_hat, seqs_true):
char_eds, char_ref_lens = [], []
for i, seq_hat_text in enumerate(seqs_hat):
seq_true_text = seqs_true[i]
hyp_chars = seq_hat_text.replace(' ', '')
ref_chars = seq_true_text.replace(' ', '')
char_ref_lens.append(len(ref_chars))
return float(sum(char_eds)) / sum(char_ref_lens)
def calculate_wer(self, seqs_hat, seqs_true):
word_eds, word_ref_lens = [], []
for i, seq_hat_text in enumerate(seqs_hat):
seq_true_text = seqs_true[i]
hyp_words = seq_hat_text.split()
ref_words = seq_true_text.split()
word_ref_lens.append(len(ref_words))
return float(sum(word_eds)) / sum(word_ref_lens)
__call__(self, ys_hat, ys_pad, is_ctc=False)
special
¶
Source code in adviser/tools/espnet_minimal/nets/e2e_asr_common.py
def __call__(self, ys_hat, ys_pad, is_ctc=False):
cer, wer = None, None
if is_ctc:
return self.calculate_cer_ctc(ys_hat, ys_pad)
elif not self.report_cer and not self.report_wer:
return cer, wer
seqs_hat, seqs_true = self.convert_to_char(ys_hat, ys_pad)
if self.report_cer:
cer = self.calculate_cer(seqs_hat, seqs_true)
if self.report_wer:
wer = self.calculate_wer(seqs_hat, seqs_true)
return cer, wer
__init__(self, char_list, sym_space, sym_blank, report_cer=False, report_wer=False)
special
¶
Source code in adviser/tools/espnet_minimal/nets/e2e_asr_common.py
def __init__(self, char_list, sym_space, sym_blank,
report_cer=False, report_wer=False):
super(ErrorCalculator, self).__init__()
self.char_list = char_list
self.space = sym_space
self.blank = sym_blank
self.report_cer = report_cer
self.report_wer = report_wer
self.idx_blank = self.char_list.index(self.blank)
if self.space in self.char_list:
self.idx_space = self.char_list.index(self.space)
else:
self.idx_space = None
calculate_cer(self, seqs_hat, seqs_true)
¶
Source code in adviser/tools/espnet_minimal/nets/e2e_asr_common.py
def calculate_cer(self, seqs_hat, seqs_true):
char_eds, char_ref_lens = [], []
for i, seq_hat_text in enumerate(seqs_hat):
seq_true_text = seqs_true[i]
hyp_chars = seq_hat_text.replace(' ', '')
ref_chars = seq_true_text.replace(' ', '')
char_ref_lens.append(len(ref_chars))
return float(sum(char_eds)) / sum(char_ref_lens)
calculate_cer_ctc(self, ys_hat, ys_pad)
¶
Source code in adviser/tools/espnet_minimal/nets/e2e_asr_common.py
def calculate_cer_ctc(self, ys_hat, ys_pad):
cers, char_ref_lens = [], []
for i, y in enumerate(ys_hat):
y_hat = [x[0] for x in groupby(y)]
y_true = ys_pad[i]
seq_hat, seq_true = [], []
for idx in y_hat:
idx = int(idx)
if idx != -1 and idx != self.idx_blank and idx != self.idx_space:
seq_hat.append(self.char_list[int(idx)])
for idx in y_true:
idx = int(idx)
if idx != -1 and idx != self.idx_blank and idx != self.idx_space:
seq_true.append(self.char_list[int(idx)])
hyp_chars = "".join(seq_hat)
ref_chars = "".join(seq_true)
cer_ctc = float(sum(cers)) / sum(char_ref_lens) if cers else None
return cer_ctc
calculate_wer(self, seqs_hat, seqs_true)
¶
Source code in adviser/tools/espnet_minimal/nets/e2e_asr_common.py
def calculate_wer(self, seqs_hat, seqs_true):
word_eds, word_ref_lens = [], []
for i, seq_hat_text in enumerate(seqs_hat):
seq_true_text = seqs_true[i]
hyp_words = seq_hat_text.split()
ref_words = seq_true_text.split()
word_ref_lens.append(len(ref_words))
return float(sum(word_eds)) / sum(word_ref_lens)
convert_to_char(self, ys_hat, ys_pad)
¶
Source code in adviser/tools/espnet_minimal/nets/e2e_asr_common.py
def convert_to_char(self, ys_hat, ys_pad):
seqs_hat, seqs_true = [], []
for i, y_hat in enumerate(ys_hat):
y_true = ys_pad[i]
eos_true = np.where(y_true == -1)[0]
eos_true = eos_true[0] if len(eos_true) > 0 else len(y_true)
# To avoid wrong higger WER than the one obtained from the decoding
# eos from y_true is used to mark the eos in y_hat
# because of that y_hats has not padded outs with -1.
seq_hat = [self.char_list[int(idx)] for idx in y_hat[:eos_true]]
seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1]
seq_hat_text = "".join(seq_hat).replace(self.space, ' ')
seq_hat_text = seq_hat_text.replace(self.blank, '')
seq_true_text = "".join(seq_true).replace(self.space, ' ')
seqs_hat.append(seq_hat_text)
seqs_true.append(seq_true_text)
return seqs_hat, seqs_true
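A small sketch (import path and toy id sequences are assumptions) of how padded id sequences are mapped back to strings and scored; note that in this minimal copy the per-sequence edit distances (char_eds/word_eds) are never filled in the listing above, so the returned rates come back as 0.0 here:
import numpy as np
from tools.espnet_minimal.nets.e2e_asr_common import ErrorCalculator  # assumed import path
char_list = ['<blank>', '<space>', 'a', 'b', '<eos>']
calc = ErrorCalculator(char_list, sym_space='<space>', sym_blank='<blank>',
                       report_cer=True, report_wer=True)
ys_hat = np.array([[2, 3, 1, 2, -1]])                        # predicted ids ("ab a"), -1 marks padding
ys_pad = np.array([[2, 3, 1, 3, -1]])                        # reference ids ("ab b")
seqs_hat, seqs_true = calc.convert_to_char(ys_hat, ys_pad)   # (['ab a'], ['ab b'])
cer, wer = calc(ys_hat, ys_pad)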
end_detect(ended_hyps, i, M=3, D_end=-10.0)
¶
End detection
described in Eq. (50) of S. Watanabe et al., "Hybrid CTC/Attention Architecture for End-to-End Speech Recognition"
:param ended_hyps: :param i: :param M: :param D_end: :return:
Source code in adviser/tools/espnet_minimal/nets/e2e_asr_common.py
def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
"""End detection
desribed in Eq. (50) of S. Watanabe et al
"Hybrid CTC/Attention Architecture for End-to-End Speech Recognition"
:param ended_hyps:
:param i:
:param M:
:param D_end:
:return:
"""
if len(ended_hyps) == 0:
return False
count = 0
best_hyp = sorted(ended_hyps, key=lambda x: x['score'], reverse=True)[0]
for m in six.moves.range(M):
# get ended_hyps with their length is i - m
hyp_length = i - m
hyps_same_length = [x for x in ended_hyps if len(x['yseq']) == hyp_length]
if len(hyps_same_length) > 0:
best_hyp_same_length = sorted(hyps_same_length, key=lambda x: x['score'], reverse=True)[0]
if best_hyp_same_length['score'] - best_hyp['score'] < D_end:
count += 1
if count == M:
return True
else:
return False
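A self-contained sketch of the stopping criterion (the hypothesis dicts are made up; only the 'yseq' and 'score' keys are used, and the import path is an assumption):
from tools.espnet_minimal.nets.e2e_asr_common import end_detect  # assumed import path
ended_hyps = [
    {'yseq': [1, 4, 2], 'score': -1.0},                # short, high-scoring hypothesis
    {'yseq': [1, 4, 5, 2], 'score': -25.0},
    {'yseq': [1, 4, 5, 6, 2], 'score': -26.0},
    {'yseq': [1, 4, 5, 6, 7, 2], 'score': -27.0},
]
print(end_detect(ended_hyps, i=6))                     # True: the best hypotheses at the last M=3 lengths
                                                       # all fall more than |D_end| below the global best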
get_vgg2l_odim(idim, in_channel=3, out_channel=128)
¶
Source code in adviser/tools/espnet_minimal/nets/e2e_asr_common.py
label_smoothing_dist(odim, lsm_type, transcript=None, blank=0)
¶
Obtain label distribution for loss smoothing
:param odim: :param lsm_type: :param blank: :param transcript: :return:
Source code in adviser/tools/espnet_minimal/nets/e2e_asr_common.py
def label_smoothing_dist(odim, lsm_type, transcript=None, blank=0):
"""Obtain label distribution for loss smoothing
:param odim:
:param lsm_type:
:param blank:
:param transcript:
:return:
"""
if transcript is not None:
with open(transcript, 'rb') as f:
trans_json = json.load(f)['utts']
if lsm_type == 'unigram':
assert transcript is not None, 'transcript is required for %s label smoothing' % lsm_type
labelcount = np.zeros(odim)
for k, v in trans_json.items():
ids = np.array([int(n) for n in v['output'][0]['tokenid'].split()])
# to avoid an error when there is no text in an uttrance
if len(ids) > 0:
labelcount[ids] += 1
labelcount[odim - 1] = len(transcript) # count <eos>
labelcount[labelcount == 0] = 1 # flooring
labelcount[blank] = 0 # remove counts for blank
labeldist = labelcount.astype(np.float32) / np.sum(labelcount)
else:
logging.error(
"Error: unexpected label smoothing type: %s" % lsm_type)
sys.exit()
return labeldist
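For 'unigram' smoothing the transcript must be an espnet-style json whose 'utts' entries carry an 'output'[0]['tokenid'] string (this structure is read directly in the code above); a self-contained sketch with a throw-away file:
import json, tempfile
from tools.espnet_minimal.nets.e2e_asr_common import label_smoothing_dist  # assumed import path
utts = {"utts": {"utt1": {"output": [{"tokenid": "1 2 2 3"}]},
                 "utt2": {"output": [{"tokenid": "2 3 3"}]}}}
with tempfile.NamedTemporaryFile('w', suffix='.json', delete=False) as f:
    json.dump(utts, f)
labeldist = label_smoothing_dist(odim=5, lsm_type='unigram', transcript=f.name, blank=0)
# labeldist is a length-5 unigram distribution; the blank count is zeroed and the last id counts as <eos>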
mt_interface
¶
MT Interface module.
MTInterface
¶
MT Interface for ESPnet model implementation.
Source code in adviser/tools/espnet_minimal/nets/mt_interface.py
class MTInterface:
"""MT Interface for ESPnet model implementation."""
@staticmethod
def add_arguments(parser):
"""Add arguments to parser."""
return parser
@classmethod
def build(cls, idim: int, odim: int, **kwargs):
"""Initialize this class with python-level args.
Args:
idim (int): The number of an input feature dim.
odim (int): The number of output vocab.
Returns:
ASRinterface: A new instance of ASRInterface.
"""
def wrap(parser):
return get_parser(parser, required=False)
args = argparse.Namespace(**kwargs)
args = fill_missing_args(args, wrap)
args = fill_missing_args(args, cls.add_arguments)
return cls(idim, odim, args)
def forward(self, xs, ilens, ys):
"""Compute loss for training.
:param xs:
For pytorch, batch of padded source sequences torch.Tensor (B, Tmax, idim)
For chainer, list of source sequences chainer.Variable
:param ilens: batch of lengths of source sequences (B)
For pytorch, torch.Tensor
For chainer, list of int
:param ys:
For pytorch, batch of padded source sequences torch.Tensor (B, Lmax)
For chainer, list of source sequences chainer.Variable
:return: loss value
:rtype: torch.Tensor for pytorch, chainer.Variable for chainer
"""
raise NotImplementedError("forward method is not implemented")
def translate(self, x, trans_args, char_list=None, rnnlm=None):
"""Translate x for evaluation.
:param ndarray x: input acouctic feature (B, T, D) or (T, D)
:param namespace trans_args: argment namespace contraining options
:param list char_list: list of characters
:param torch.nn.Module rnnlm: language model module
:return: N-best decoding results
:rtype: list
"""
raise NotImplementedError("translate method is not implemented")
def translate_batch(self, x, trans_args, char_list=None, rnnlm=None):
"""Beam search implementation for batch.
:param torch.Tensor x: encoder hidden state sequences (B, Tmax, Henc)
:param namespace trans_args: argument namespace containing options
:param list char_list: list of characters
:param torch.nn.Module rnnlm: language model module
:return: N-best decoding results
:rtype: list
"""
raise NotImplementedError("Batch decoding is not supported yet.")
def calculate_all_attentions(self, xs, ilens, ys):
"""Caluculate attention.
:param list xs_pad: list of padded input sequences [(T1, idim), (T2, idim), ...]
:param ndarray ilens: batch of lengths of input sequences (B)
:param list ys: list of character id sequence tensor [(L1), (L2), (L3), ...]
:return: attention weights (B, Lmax, Tmax)
:rtype: float ndarray
"""
raise NotImplementedError("calculate_all_attentions method is not implemented")
@property
def attention_plot_class(self):
"""Get attention plot class."""
from tools.espnet_minimal.asr.asr_utils import PlotAttentionReport
return PlotAttentionReport
attention_plot_class
property
readonly
¶
Get attention plot class.
add_arguments(parser)
staticmethod
¶
build(idim, odim, **kwargs)
classmethod
¶
Initialize this class with python-level args.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
idim | int | The dimension of the input features. | required |
odim | int | The size of the output vocabulary. | required |
Returns:
Type | Description |
---|---|
ASRInterface | A new instance of ASRInterface. |
Source code in adviser/tools/espnet_minimal/nets/mt_interface.py
@classmethod
def build(cls, idim: int, odim: int, **kwargs):
"""Initialize this class with python-level args.
Args:
idim (int): The number of an input feature dim.
odim (int): The number of output vocab.
Returns:
ASRinterface: A new instance of ASRInterface.
"""
def wrap(parser):
return get_parser(parser, required=False)
args = argparse.Namespace(**kwargs)
args = fill_missing_args(args, wrap)
args = fill_missing_args(args, cls.add_arguments)
return cls(idim, odim, args)
calculate_all_attentions(self, xs, ilens, ys)
¶
Calculate attention.
:param list xs_pad: list of padded input sequences [(T1, idim), (T2, idim), ...] :param ndarray ilens: batch of lengths of input sequences (B) :param list ys: list of character id sequence tensor [(L1), (L2), (L3), ...] :return: attention weights (B, Lmax, Tmax) :rtype: float ndarray
Source code in adviser/tools/espnet_minimal/nets/mt_interface.py
def calculate_all_attentions(self, xs, ilens, ys):
"""Caluculate attention.
:param list xs_pad: list of padded input sequences [(T1, idim), (T2, idim), ...]
:param ndarray ilens: batch of lengths of input sequences (B)
:param list ys: list of character id sequence tensor [(L1), (L2), (L3), ...]
:return: attention weights (B, Lmax, Tmax)
:rtype: float ndarray
"""
raise NotImplementedError("calculate_all_attentions method is not implemented")
forward(self, xs, ilens, ys)
¶
Compute loss for training.
:param xs: For pytorch, batch of padded source sequences torch.Tensor (B, Tmax, idim) For chainer, list of source sequences chainer.Variable :param ilens: batch of lengths of source sequences (B) For pytorch, torch.Tensor For chainer, list of int :param ys: For pytorch, batch of padded source sequences torch.Tensor (B, Lmax) For chainer, list of source sequences chainer.Variable :return: loss value :rtype: torch.Tensor for pytorch, chainer.Variable for chainer
Source code in adviser/tools/espnet_minimal/nets/mt_interface.py
def forward(self, xs, ilens, ys):
"""Compute loss for training.
:param xs:
For pytorch, batch of padded source sequences torch.Tensor (B, Tmax, idim)
For chainer, list of source sequences chainer.Variable
:param ilens: batch of lengths of source sequences (B)
For pytorch, torch.Tensor
For chainer, list of int
:param ys:
For pytorch, batch of padded source sequences torch.Tensor (B, Lmax)
For chainer, list of source sequences chainer.Variable
:return: loss value
:rtype: torch.Tensor for pytorch, chainer.Variable for chainer
"""
raise NotImplementedError("forward method is not implemented")
translate(self, x, trans_args, char_list=None, rnnlm=None)
¶
Translate x for evaluation.
:param ndarray x: input acoustic feature (B, T, D) or (T, D) :param namespace trans_args: argument namespace containing options :param list char_list: list of characters :param torch.nn.Module rnnlm: language model module :return: N-best decoding results :rtype: list
Source code in adviser/tools/espnet_minimal/nets/mt_interface.py
def translate(self, x, trans_args, char_list=None, rnnlm=None):
"""Translate x for evaluation.
:param ndarray x: input acouctic feature (B, T, D) or (T, D)
:param namespace trans_args: argment namespace contraining options
:param list char_list: list of characters
:param torch.nn.Module rnnlm: language model module
:return: N-best decoding results
:rtype: list
"""
raise NotImplementedError("translate method is not implemented")
translate_batch(self, x, trans_args, char_list=None, rnnlm=None)
¶
Beam search implementation for batch.
:param torch.Tensor x: encoder hidden state sequences (B, Tmax, Henc) :param namespace trans_args: argument namespace containing options :param list char_list: list of characters :param torch.nn.Module rnnlm: language model module :return: N-best decoding results :rtype: list
Source code in adviser/tools/espnet_minimal/nets/mt_interface.py
def translate_batch(self, x, trans_args, char_list=None, rnnlm=None):
"""Beam search implementation for batch.
:param torch.Tensor x: encoder hidden state sequences (B, Tmax, Henc)
:param namespace trans_args: argument namespace containing options
:param list char_list: list of characters
:param torch.nn.Module rnnlm: language model module
:return: N-best decoding results
:rtype: list
"""
raise NotImplementedError("Batch decoding is not supported yet.")
pytorch_backend
special
¶
ctc
¶
CTC (Module)
¶
CTC module
:param int odim: dimension of outputs :param int eprojs: number of encoder projection units :param float dropout_rate: dropout rate (0.0 ~ 1.0) :param str ctc_type: builtin or warpctc :param bool reduce: reduce the CTC loss into a scalar
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/ctc.py
class CTC(torch.nn.Module):
"""CTC module
:param int odim: dimension of outputs
:param int eprojs: number of encoder projection units
:param float dropout_rate: dropout rate (0.0 ~ 1.0)
:param str ctc_type: builtin or warpctc
:param bool reduce: reduce the CTC loss into a scalar
"""
def __init__(self, odim, eprojs, dropout_rate, ctc_type='warpctc', reduce=True):
super().__init__()
self.dropout_rate = dropout_rate
self.loss = None
self.ctc_lo = torch.nn.Linear(eprojs, odim)
self.ctc_type = ctc_type
if self.ctc_type == 'builtin':
reduction_type = 'sum' if reduce else 'none'
self.ctc_loss = torch.nn.CTCLoss(reduction=reduction_type)
elif self.ctc_type == 'warpctc':
import warpctc_pytorch as warp_ctc
self.ctc_loss = warp_ctc.CTCLoss(size_average=True, reduce=reduce)
else:
raise ValueError('ctc_type must be "builtin" or "warpctc": {}'
.format(self.ctc_type))
self.ignore_id = -1
self.reduce = reduce
def loss_fn(self, th_pred, th_target, th_ilen, th_olen):
if self.ctc_type == 'builtin':
th_pred = th_pred.log_softmax(2)
# Use the deterministic CuDNN implementation of CTC loss to avoid
# [issue#17798](https://github.com/pytorch/pytorch/issues/17798)
with torch.backends.cudnn.flags(deterministic=True):
loss = self.ctc_loss(th_pred, th_target, th_ilen, th_olen)
# Batch-size average
loss = loss / th_pred.size(1)
return loss
elif self.ctc_type == 'warpctc':
return self.ctc_loss(th_pred, th_target, th_ilen, th_olen)
else:
raise NotImplementedError
def forward(self, hs_pad, hlens, ys_pad):
"""CTC forward
:param torch.Tensor hs_pad: batch of padded hidden state sequences (B, Tmax, D)
:param torch.Tensor hlens: batch of lengths of hidden state sequences (B)
:param torch.Tensor ys_pad: batch of padded character id sequence tensor (B, Lmax)
:return: ctc loss value
:rtype: torch.Tensor
"""
# TODO(kan-bayashi): need to make more smart way
ys = [y[y != self.ignore_id] for y in ys_pad] # parse padded ys
self.loss = None
hlens = torch.from_numpy(np.fromiter(hlens, dtype=np.int32))
olens = torch.from_numpy(np.fromiter(
(x.size(0) for x in ys), dtype=np.int32))
# zero padding for hs
ys_hat = self.ctc_lo(F.dropout(hs_pad, p=self.dropout_rate))
# zero padding for ys
ys_true = torch.cat(ys).cpu().int() # batch x olen
# get length info
logging.info(self.__class__.__name__ + ' input lengths: ' + ''.join(str(hlens).split('\n')))
logging.info(self.__class__.__name__ + ' output lengths: ' + ''.join(str(olens).split('\n')))
# get ctc loss
# expected shape of seqLength x batchSize x alphabet_size
dtype = ys_hat.dtype
ys_hat = ys_hat.transpose(0, 1)
if self.ctc_type == "warpctc":
# warpctc only supports float32
ys_hat = ys_hat.to(dtype=torch.float32)
else:
# use GPU when using the cuDNN implementation
ys_true = to_device(self, ys_true)
self.loss = to_device(self, self.loss_fn(ys_hat, ys_true, hlens, olens)).to(dtype=dtype)
if self.reduce:
# NOTE: sum() is needed to keep consistency since warpctc return as tensor w/ shape (1,)
# but builtin return as tensor w/o shape (scalar).
self.loss = self.loss.sum()
logging.info('ctc loss:' + str(float(self.loss)))
return self.loss
def log_softmax(self, hs_pad):
"""log_softmax of frame activations
:param torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
:return: log softmax applied 3d tensor (B, Tmax, odim)
:rtype: torch.Tensor
"""
return F.log_softmax(self.ctc_lo(hs_pad), dim=2)
def argmax(self, hs_pad):
"""argmax of frame activations
:param torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
:return: argmax applied 2d tensor (B, Tmax)
:rtype: torch.Tensor
"""
return torch.argmax(self.ctc_lo(hs_pad), dim=2)
__init__(self, odim, eprojs, dropout_rate, ctc_type='warpctc', reduce=True)
special
¶
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/ctc.py
def __init__(self, odim, eprojs, dropout_rate, ctc_type='warpctc', reduce=True):
super().__init__()
self.dropout_rate = dropout_rate
self.loss = None
self.ctc_lo = torch.nn.Linear(eprojs, odim)
self.ctc_type = ctc_type
if self.ctc_type == 'builtin':
reduction_type = 'sum' if reduce else 'none'
self.ctc_loss = torch.nn.CTCLoss(reduction=reduction_type)
elif self.ctc_type == 'warpctc':
import warpctc_pytorch as warp_ctc
self.ctc_loss = warp_ctc.CTCLoss(size_average=True, reduce=reduce)
else:
raise ValueError('ctc_type must be "builtin" or "warpctc": {}'
.format(self.ctc_type))
self.ignore_id = -1
self.reduce = reduce
argmax(self, hs_pad)
¶
argmax of frame activations
:param torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs) :return: argmax applied 2d tensor (B, Tmax) :rtype: torch.Tensor
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/ctc.py
forward(self, hs_pad, hlens, ys_pad)
¶
CTC forward
:param torch.Tensor hs_pad: batch of padded hidden state sequences (B, Tmax, D) :param torch.Tensor hlens: batch of lengths of hidden state sequences (B) :param torch.Tensor ys_pad: batch of padded character id sequence tensor (B, Lmax) :return: ctc loss value :rtype: torch.Tensor
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/ctc.py
def forward(self, hs_pad, hlens, ys_pad):
"""CTC forward
:param torch.Tensor hs_pad: batch of padded hidden state sequences (B, Tmax, D)
:param torch.Tensor hlens: batch of lengths of hidden state sequences (B)
:param torch.Tensor ys_pad: batch of padded character id sequence tensor (B, Lmax)
:return: ctc loss value
:rtype: torch.Tensor
"""
# TODO(kan-bayashi): need to make more smart way
ys = [y[y != self.ignore_id] for y in ys_pad] # parse padded ys
self.loss = None
hlens = torch.from_numpy(np.fromiter(hlens, dtype=np.int32))
olens = torch.from_numpy(np.fromiter(
(x.size(0) for x in ys), dtype=np.int32))
# zero padding for hs
ys_hat = self.ctc_lo(F.dropout(hs_pad, p=self.dropout_rate))
# zero padding for ys
ys_true = torch.cat(ys).cpu().int() # batch x olen
# get length info
logging.info(self.__class__.__name__ + ' input lengths: ' + ''.join(str(hlens).split('\n')))
logging.info(self.__class__.__name__ + ' output lengths: ' + ''.join(str(olens).split('\n')))
# get ctc loss
# expected shape of seqLength x batchSize x alphabet_size
dtype = ys_hat.dtype
ys_hat = ys_hat.transpose(0, 1)
if self.ctc_type == "warpctc":
# warpctc only supports float32
ys_hat = ys_hat.to(dtype=torch.float32)
else:
# use GPU when using the cuDNN implementation
ys_true = to_device(self, ys_true)
self.loss = to_device(self, self.loss_fn(ys_hat, ys_true, hlens, olens)).to(dtype=dtype)
if self.reduce:
# NOTE: sum() is needed to keep consistency since warpctc return as tensor w/ shape (1,)
# but builtin return as tensor w/o shape (scalar).
self.loss = self.loss.sum()
logging.info('ctc loss:' + str(float(self.loss)))
return self.loss
log_softmax(self, hs_pad)
¶
log_softmax of frame activations
:param torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs) :return: log softmax applied 3d tensor (B, Tmax, odim) :rtype: torch.Tensor
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/ctc.py
loss_fn(self, th_pred, th_target, th_ilen, th_olen)
¶
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/ctc.py
def loss_fn(self, th_pred, th_target, th_ilen, th_olen):
if self.ctc_type == 'builtin':
th_pred = th_pred.log_softmax(2)
# Use the deterministic CuDNN implementation of CTC loss to avoid
# [issue#17798](https://github.com/pytorch/pytorch/issues/17798)
with torch.backends.cudnn.flags(deterministic=True):
loss = self.ctc_loss(th_pred, th_target, th_ilen, th_olen)
# Batch-size average
loss = loss / th_pred.size(1)
return loss
elif self.ctc_type == 'warpctc':
return self.ctc_loss(th_pred, th_target, th_ilen, th_olen)
else:
raise NotImplementedError
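A runnable sketch of the builtin-CTC path (sizes and the import path are assumptions; -1 marks padded targets, matching ignore_id):
import torch
from tools.espnet_minimal.nets.pytorch_backend.ctc import CTC  # assumed import path
odim, eprojs = 10, 16
ctc = CTC(odim, eprojs, dropout_rate=0.0, ctc_type='builtin')
hs_pad = torch.randn(2, 12, eprojs)                      # padded encoder outputs (B, Tmax, eprojs)
hlens = [12, 10]                                         # valid encoder lengths per utterance
ys_pad = torch.tensor([[3, 5, 2, -1], [4, 1, -1, -1]])   # padded label ids, -1 = ignore_id
loss = ctc(hs_pad, hlens, ys_pad)                        # scalar CTC loss (summed, then divided by batch size)
log_probs = ctc.log_softmax(hs_pad)                      # (B, Tmax, odim) frame-level log posteriors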
ctc_for(args, odim, reduce=True)
¶
Returns the CTC module for the given args and output dimension
:param Namespace args: the program args :param int odim : The output dimension :param bool reduce : return the CTC loss in a scalar :return: the corresponding CTC module
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/ctc.py
def ctc_for(args, odim, reduce=True):
"""Returns the CTC module for the given args and output dimension
:param Namespace args: the program args
:param int odim : The output dimension
:param bool reduce : return the CTC loss in a scalar
:return: the corresponding CTC module
"""
num_encs = getattr(args, "num_encs", 1) # use getattr to keep compatibility
if num_encs == 1:
# compatible with single encoder asr mode
return CTC(odim, args.eprojs, args.dropout_rate, ctc_type='builtin', reduce=reduce)
# changed this to use builtin ctc rather
# than warpctc, so we have nothing to
# install and it's just about the loss anyways.
elif num_encs >= 1:
ctcs_list = torch.nn.ModuleList()
if args.share_ctc:
# use dropout_rate of the first encoder
ctc = CTC(odim, args.eprojs, args.dropout_rate[0], ctc_type=args.ctc_type, reduce=reduce)
ctcs_list.append(ctc)
else:
for idx in range(num_encs):
ctc = CTC(odim, args.eprojs, args.dropout_rate[idx], ctc_type=args.ctc_type, reduce=reduce)
ctcs_list.append(ctc)
return ctcs_list
else:
raise ValueError("Number of encoders needs to be more than one. {}".format(num_encs))
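For the single-encoder case, ctc_for only needs eprojs and dropout_rate on the args namespace (a sketch; values are arbitrary):
import argparse
from tools.espnet_minimal.nets.pytorch_backend.ctc import ctc_for  # assumed import path
args = argparse.Namespace(eprojs=320, dropout_rate=0.0)
ctc_module = ctc_for(args, odim=52)     # num_encs defaults to 1, so a single builtin CTC module is returned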
e2e_asr
¶
RNN sequence-to-sequence speech recognition model (pytorch).
CTC_LOSS_THRESHOLD
¶
E2E (ASRInterface, Module)
¶
E2E module.
:param int idim: dimension of inputs :param int odim: dimension of outputs :param Namespace args: argument Namespace containing options
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_asr.py
class E2E(ASRInterface, torch.nn.Module):
"""E2E module.
:param int idim: dimension of inputs
:param int odim: dimension of outputs
:param Namespace args: argument Namespace containing options
"""
@staticmethod
def add_arguments(parser):
"""Add arguments."""
E2E.encoder_add_arguments(parser)
E2E.attention_add_arguments(parser)
E2E.decoder_add_arguments(parser)
return parser
@staticmethod
def encoder_add_arguments(parser):
"""Add arguments for the encoder."""
group = parser.add_argument_group("E2E encoder setting")
# encoder
group.add_argument('--etype', default='blstmp', type=str,
choices=['lstm', 'blstm', 'lstmp', 'blstmp', 'vgglstmp', 'vggblstmp', 'vgglstm', 'vggblstm',
'gru', 'bgru', 'grup', 'bgrup', 'vgggrup', 'vggbgrup', 'vgggru', 'vggbgru'],
help='Type of encoder network architecture')
group.add_argument('--elayers', default=4, type=int,
help='Number of encoder layers (for shared recognition part in multi-speaker asr mode)')
group.add_argument('--eunits', '-u', default=300, type=int,
help='Number of encoder hidden units')
group.add_argument('--eprojs', default=320, type=int,
help='Number of encoder projection units')
group.add_argument('--subsample', default="1", type=str,
help='Subsample input frames x_y_z means subsample every x frame at 1st layer, '
'every y frame at 2nd layer etc.')
return parser
@staticmethod
def attention_add_arguments(parser):
"""Add arguments for the attention."""
group = parser.add_argument_group("E2E attention setting")
# attention
group.add_argument('--atype', default='dot', type=str,
choices=['noatt', 'dot', 'add', 'location', 'coverage',
'coverage_location', 'location2d', 'location_recurrent',
'multi_head_dot', 'multi_head_add', 'multi_head_loc',
'multi_head_multi_res_loc'],
help='Type of attention architecture')
group.add_argument('--adim', default=320, type=int,
help='Number of attention transformation dimensions')
group.add_argument('--awin', default=5, type=int,
help='Window size for location2d attention')
group.add_argument('--aheads', default=4, type=int,
help='Number of heads for multi head attention')
group.add_argument('--aconv-chans', default=-1, type=int,
help='Number of attention convolution channels \
(negative value indicates no location-aware attention)')
group.add_argument('--aconv-filts', default=100, type=int,
help='Number of attention convolution filters \
(negative value indicates no location-aware attention)')
group.add_argument('--dropout-rate', default=0.0, type=float,
help='Dropout rate for the encoder')
return parser
@staticmethod
def decoder_add_arguments(parser):
"""Add arguments for the decoder."""
group = parser.add_argument_group("E2E decoder setting")
group.add_argument('--dtype', default='lstm', type=str,
choices=['lstm', 'gru'],
help='Type of decoder network architecture')
group.add_argument('--dlayers', default=1, type=int,
help='Number of decoder layers')
group.add_argument('--dunits', default=320, type=int,
help='Number of decoder hidden units')
group.add_argument('--dropout-rate-decoder', default=0.0, type=float,
help='Dropout rate for the decoder')
group.add_argument('--sampling-probability', default=0.0, type=float,
help='Ratio of predicted labels fed back to decoder')
group.add_argument('--lsm-type', const='', default='', type=str, nargs='?',
choices=['', 'unigram'],
help='Apply label smoothing with a specified distribution type')
return parser
def __init__(self, idim, odim, args):
"""Construct an E2E object.
:param int idim: dimension of inputs
:param int odim: dimension of outputs
:param Namespace args: argument Namespace containing options
"""
super(E2E, self).__init__() # This loads default arguments,
# but not calling this yields the same error, so it's not why things break.
torch.nn.Module.__init__(self)
self.mtlalpha = args.mtlalpha
assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be [0.0, 1.0]"
self.etype = args.etype
self.verbose = args.verbose
# NOTE: for self.build method
args.char_list = getattr(args, "char_list", None)
self.char_list = args.char_list
self.outdir = args.outdir
self.space = args.sym_space
self.blank = args.sym_blank
# below means the last number becomes eos/sos ID
# note that sos/eos IDs are identical
self.sos = odim - 1
self.eos = odim - 1
# subsample info
# +1 means input (+1) and layer outputs (args.elayers)
subsample = np.ones(args.elayers + 1, dtype=np.int64)  # np.int is deprecated in recent NumPy
if args.etype.endswith("p") and not args.etype.startswith("vgg"):
ss = args.subsample.split("_")
for j in range(min(args.elayers + 1, len(ss))):
subsample[j] = int(ss[j])
else:
logging.warning(
'Subsampling is not performed for vgg*. It is performed in max pooling layers at CNN.')
logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
self.subsample = subsample
# label smoothing info
if args.lsm_type and os.path.isfile(args.train_json):
logging.info("Use label smoothing with " + args.lsm_type)
labeldist = label_smoothing_dist(odim, args.lsm_type, transcript=args.train_json)
else:
labeldist = None
if getattr(args, "use_frontend", False): # use getattr to keep compatibility
# Relative importing because of using python3 syntax
from tools.espnet_minimal.nets.pytorch_backend.frontends.feature_transform \
import feature_transform_for
from tools.espnet_minimal.nets.pytorch_backend.frontends.frontend \
import frontend_for
self.frontend = frontend_for(args, idim)
self.feature_transform = feature_transform_for(args, (idim - 1) * 2)
idim = args.n_mels
else:
self.frontend = None
# encoder
self.enc = encoder_for(args, idim, self.subsample)
# ctc
# self.ctc = ctc_for(args, odim) <-- if this is executed, the shapes don't match.
# The missing/unexpected arguments are not fixed by removing this however.
# attention
self.att = att_for(args)
# decoder
self.dec = decoder_for(args, odim, self.sos, self.eos, self.att, labeldist)
# weight initialization
self.init_like_chainer()
self.report_cer = False
self.report_wer = False
self.rnnlm = None
self.logzero = -10000000000.0
self.loss = None
self.acc = None
def init_like_chainer(self):
"""Initialize weight like chainer.
chainer basically uses LeCun way: W ~ Normal(0, fan_in ** -0.5), b = 0
pytorch basically uses W, b ~ Uniform(-fan_in**-0.5, fan_in**-0.5)
however, there are two exceptions as far as I know.
- EmbedID.W ~ Normal(0, 1)
- LSTM.upward.b[forget_gate_range] = 1 (but not used in NStepLSTM)
"""
lecun_normal_init_parameters(self)
# exceptions
# embed weight ~ Normal(0, 1)
self.dec.embed.weight.data.normal_(0, 1)
# forget-bias = 1.0
# https://discuss.pytorch.org/t/set-forget-gate-bias-of-lstm/1745
for l in six.moves.range(len(self.dec.decoder)):
set_forget_bias_to_one(self.dec.decoder[l].bias_ih)
def forward(self, xs_pad, ilens, ys_pad):
"""E2E forward.
:param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
:param torch.Tensor ilens: batch of lengths of input sequences (B)
:param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax)
:return: loss value
:rtype: torch.Tensor
"""
# 0. Frontend
if self.frontend is not None:
hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens)
hs_pad, hlens = self.feature_transform(hs_pad, hlens)
else:
hs_pad, hlens = xs_pad, ilens
# 1. Encoder
hs_pad, hlens, _ = self.enc(hs_pad, hlens)
# 2. CTC loss
if self.mtlalpha == 0:
self.loss_ctc = None
else:
self.loss_ctc = self.ctc(hs_pad, hlens, ys_pad)
# 3. attention loss
if self.mtlalpha == 1:
self.loss_att, acc = None, None
else:
self.loss_att, acc, _ = self.dec(hs_pad, hlens, ys_pad)
self.acc = acc
# 4. compute cer without beam search
if self.mtlalpha == 0 or self.char_list is None:
cer_ctc = None
else:
cers = []
y_hats = self.ctc.argmax(hs_pad).data
for i, y in enumerate(y_hats):
y_hat = [x[0] for x in groupby(y)]
y_true = ys_pad[i]
seq_hat = [self.char_list[int(idx)] for idx in y_hat if int(idx) != -1]
seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1]
seq_hat_text = "".join(seq_hat).replace(self.space, ' ')
seq_hat_text = seq_hat_text.replace(self.blank, '')
seq_true_text = "".join(seq_true).replace(self.space, ' ')
hyp_chars = seq_hat_text.replace(' ', '')
ref_chars = seq_true_text.replace(' ', '')
cer_ctc = sum(cers) / len(cers) if cers else None
# 5. compute cer/wer
if self.training or not (self.report_cer or self.report_wer):
cer, wer = 0.0, 0.0
# oracle_cer, oracle_wer = 0.0, 0.0
else:
if self.recog_args.ctc_weight > 0.0:
lpz = self.ctc.log_softmax(hs_pad).data
else:
lpz = None
word_eds, word_ref_lens, char_eds, char_ref_lens = [], [], [], []
nbest_hyps = self.dec.recognize_beam_batch(
hs_pad, torch.tensor(hlens), lpz,
self.recog_args, self.char_list,
self.rnnlm)
# remove <sos> and <eos>
y_hats = [nbest_hyp[0]['yseq'][1:-1] for nbest_hyp in nbest_hyps]
for i, y_hat in enumerate(y_hats):
y_true = ys_pad[i]
seq_hat = [self.char_list[int(idx)] for idx in y_hat if int(idx) != -1]
seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1]
seq_hat_text = "".join(seq_hat).replace(self.recog_args.space, ' ')
seq_hat_text = seq_hat_text.replace(self.recog_args.blank, '')
seq_true_text = "".join(seq_true).replace(self.recog_args.space, ' ')
hyp_words = seq_hat_text.split()
ref_words = seq_true_text.split()
word_ref_lens.append(len(ref_words))
hyp_chars = seq_hat_text.replace(' ', '')
ref_chars = seq_true_text.replace(' ', '')
char_ref_lens.append(len(ref_chars))
wer = 0.0 if not self.report_wer else float(sum(word_eds)) / sum(word_ref_lens)
cer = 0.0 if not self.report_cer else float(sum(char_eds)) / sum(char_ref_lens)
alpha = self.mtlalpha
if alpha == 0:
self.loss = self.loss_att
loss_att_data = float(self.loss_att)
loss_ctc_data = None
elif alpha == 1:
self.loss = self.loss_ctc
loss_att_data = None
loss_ctc_data = float(self.loss_ctc)
else:
self.loss = alpha * self.loss_ctc + (1 - alpha) * self.loss_att
loss_att_data = float(self.loss_att)
loss_ctc_data = float(self.loss_ctc)
loss_data = float(self.loss)
if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
self.reporter.report(loss_ctc_data, loss_att_data, acc, cer_ctc, cer, wer, loss_data)
else:
logging.warning('loss (=%f) is not correct', loss_data)
return self.loss
def scorers(self):
"""Scorers."""
return dict(decoder=self.dec, ctc=CTCPrefixScorer(self.ctc, self.eos))
def encode(self, x):
"""Encode acoustic features.
:param ndarray x: input acoustic feature (T, D)
:return: encoder outputs
:rtype: torch.Tensor
"""
self.eval()
ilens = [x.shape[0]]
# subsample frame
x = x[::self.subsample[0], :]
p = next(self.parameters())
h = torch.as_tensor(x, device=p.device, dtype=p.dtype)
# make a utt list (1) to use the same interface for encoder
hs = h.contiguous().unsqueeze(0)
# 0. Frontend
if self.frontend is not None:
enhanced, hlens, mask = self.frontend(hs, ilens)
hs, hlens = self.feature_transform(enhanced, hlens)
else:
hs, hlens = hs, ilens
# 1. encoder
hs, _, _ = self.enc(hs, hlens)
return hs.squeeze(0)
def recognize(self, x, recog_args, char_list, rnnlm=None):
"""E2E beam search.
:param ndarray x: input acoustic feature (T, D)
:param Namespace recog_args: argument Namespace containing options
:param list char_list: list of characters
:param torch.nn.Module rnnlm: language model module
:return: N-best decoding results
:rtype: list
"""
hs = self.encode(x).unsqueeze(0)
# calculate log P(z_t|X) for CTC scores
if recog_args.ctc_weight > 0.0:
lpz = self.ctc.log_softmax(hs)[0]
else:
lpz = None
# 2. Decoder
# decode the first utterance
y = self.dec.recognize_beam(hs[0], lpz, recog_args, char_list, rnnlm)
return y
def recognize_batch(self, xs, recog_args, char_list, rnnlm=None):
"""E2E beam search.
:param list xs: list of input acoustic feature arrays [(T_1, D), (T_2, D), ...]
:param Namespace recog_args: argument Namespace containing options
:param list char_list: list of characters
:param torch.nn.Module rnnlm: language model module
:return: N-best decoding results
:rtype: list
"""
prev = self.training
self.eval()
ilens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64)
# subsample frame
xs = [xx[::self.subsample[0], :] for xx in xs]
xs = [to_device(self, to_torch_tensor(xx).float()) for xx in xs]
xs_pad = pad_list(xs, 0.0)
# 0. Frontend
if self.frontend is not None:
enhanced, hlens, mask = self.frontend(xs_pad, ilens)
hs_pad, hlens = self.feature_transform(enhanced, hlens)
else:
hs_pad, hlens = xs_pad, ilens
# 1. Encoder
hs_pad, hlens, _ = self.enc(hs_pad, hlens)
# calculate log P(z_t|X) for CTC scores
if recog_args.ctc_weight > 0.0:
lpz = self.ctc.log_softmax(hs_pad)
normalize_score = False
else:
lpz = None
normalize_score = True
# 2. Decoder
hlens = torch.tensor(list(map(int, hlens))) # make sure hlens is tensor
y = self.dec.recognize_beam_batch(hs_pad, hlens, lpz, recog_args, char_list,
rnnlm, normalize_score=normalize_score)
if prev:
self.train()
return y
def enhance(self, xs):
"""Forward only in the frontend stage.
:param ndarray xs: input acoustic feature (T, C, F)
:return: enhanced feature
:rtype: torch.Tensor
"""
if self.frontend is None:
raise RuntimeError('Frontend doesn\'t exist')
prev = self.training
self.eval()
ilens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64)
# subsample frame
xs = [xx[::self.subsample[0], :] for xx in xs]
xs = [to_device(self, to_torch_tensor(xx).float()) for xx in xs]
xs_pad = pad_list(xs, 0.0)
enhanced, hlensm, mask = self.frontend(xs_pad, ilens)
if prev:
self.train()
return enhanced.cpu().numpy(), mask.cpu().numpy(), ilens
def calculate_all_attentions(self, xs_pad, ilens, ys_pad):
"""E2E attention calculation.
:param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
:param torch.Tensor ilens: batch of lengths of input sequences (B)
:param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax)
:return: attention weights with the following shape,
1) multi-head case => attention weights (B, H, Lmax, Tmax),
2) other case => attention weights (B, Lmax, Tmax).
:rtype: float ndarray
"""
with torch.no_grad():
# 0. Frontend
if self.frontend is not None:
hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens)
hs_pad, hlens = self.feature_transform(hs_pad, hlens)
else:
hs_pad, hlens = xs_pad, ilens
# 1. Encoder
hpad, hlens, _ = self.enc(hs_pad, hlens)
# 2. Decoder
att_ws = self.dec.calculate_all_attentions(hpad, hlens, ys_pad)
return att_ws
def subsample_frames(self, x):
"""Subsample speech frames in the encoder."""
# subsample frame
x = x[::self.subsample[0], :]
ilen = [x.shape[0]]
h = to_device(self, torch.from_numpy(
np.array(x, dtype=np.float32)))
h.contiguous()
return h, ilen
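The static add_arguments helpers above collect the encoder, attention and decoder hyper-parameters. The following sketch (not part of the ADVISER sources, and assuming the adviser/ directory is on PYTHONPATH) only demonstrates that flow; instantiating E2E(idim, odim, args) additionally requires trainer-side fields such as mtlalpha, verbose, outdir, sym_space and sym_blank that are not registered by these helpers.

import argparse

from tools.espnet_minimal.nets.pytorch_backend.e2e_asr import E2E

parser = argparse.ArgumentParser()
E2E.add_arguments(parser)        # registers --etype, --atype, --dtype, ... with their defaults
args = parser.parse_args([])     # take the documented defaults
print(args.etype, args.adim, args.dunits)   # blstmp 320 320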
__init__(self, idim, odim, args)
special
¶
Construct an E2E object.
:param int idim: dimension of inputs
:param int odim: dimension of outputs
:param Namespace args: argument Namespace containing options
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_asr.py
def __init__(self, idim, odim, args):
"""Construct an E2E object.
:param int idim: dimension of inputs
:param int odim: dimension of outputs
:param Namespace args: argument Namespace containing options
"""
super(E2E, self).__init__() # This loads default arguments,
# but not calling this yields the same error, so it's not why things break.
torch.nn.Module.__init__(self)
self.mtlalpha = args.mtlalpha
assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be [0.0, 1.0]"
self.etype = args.etype
self.verbose = args.verbose
# NOTE: for self.build method
args.char_list = getattr(args, "char_list", None)
self.char_list = args.char_list
self.outdir = args.outdir
self.space = args.sym_space
self.blank = args.sym_blank
# below means the last number becomes eos/sos ID
# note that sos/eos IDs are identical
self.sos = odim - 1
self.eos = odim - 1
# subsample info
# +1 means input (+1) and layer outputs (args.elayers)
subsample = np.ones(args.elayers + 1, dtype=np.int64)  # np.int is deprecated in recent NumPy
if args.etype.endswith("p") and not args.etype.startswith("vgg"):
ss = args.subsample.split("_")
for j in range(min(args.elayers + 1, len(ss))):
subsample[j] = int(ss[j])
else:
logging.warning(
'Subsampling is not performed for vgg*. It is performed in max pooling layers at CNN.')
logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
self.subsample = subsample
# label smoothing info
if args.lsm_type and os.path.isfile(args.train_json):
logging.info("Use label smoothing with " + args.lsm_type)
labeldist = label_smoothing_dist(odim, args.lsm_type, transcript=args.train_json)
else:
labeldist = None
if getattr(args, "use_frontend", False): # use getattr to keep compatibility
# Relative importing because of using python3 syntax
from tools.espnet_minimal.nets.pytorch_backend.frontends.feature_transform \
import feature_transform_for
from tools.espnet_minimal.nets.pytorch_backend.frontends.frontend \
import frontend_for
self.frontend = frontend_for(args, idim)
self.feature_transform = feature_transform_for(args, (idim - 1) * 2)
idim = args.n_mels
else:
self.frontend = None
# encoder
self.enc = encoder_for(args, idim, self.subsample)
# ctc
# self.ctc = ctc_for(args, odim) <-- if this is executed, the shapes don't match.
# The missing/unexpected arguments are not fixed by removing this however.
# attention
self.att = att_for(args)
# decoder
self.dec = decoder_for(args, odim, self.sos, self.eos, self.att, labeldist)
# weight initialization
self.init_like_chainer()
self.report_cer = False
self.report_wer = False
self.rnnlm = None
self.logzero = -10000000000.0
self.loss = None
self.acc = None
add_arguments(parser)
staticmethod
¶
Add arguments.
attention_add_arguments(parser)
staticmethod
¶
Add arguments for the attention.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_asr.py
@staticmethod
def attention_add_arguments(parser):
"""Add arguments for the attention."""
group = parser.add_argument_group("E2E attention setting")
# attention
group.add_argument('--atype', default='dot', type=str,
choices=['noatt', 'dot', 'add', 'location', 'coverage',
'coverage_location', 'location2d', 'location_recurrent',
'multi_head_dot', 'multi_head_add', 'multi_head_loc',
'multi_head_multi_res_loc'],
help='Type of attention architecture')
group.add_argument('--adim', default=320, type=int,
help='Number of attention transformation dimensions')
group.add_argument('--awin', default=5, type=int,
help='Window size for location2d attention')
group.add_argument('--aheads', default=4, type=int,
help='Number of heads for multi head attention')
group.add_argument('--aconv-chans', default=-1, type=int,
help='Number of attention convolution channels \
(negative value indicates no location-aware attention)')
group.add_argument('--aconv-filts', default=100, type=int,
help='Number of attention convolution filters \
(negative value indicates no location-aware attention)')
group.add_argument('--dropout-rate', default=0.0, type=float,
help='Dropout rate for the encoder')
return parser
calculate_all_attentions(self, xs_pad, ilens, ys_pad)
¶
E2E attention calculation.
:param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
:param torch.Tensor ilens: batch of lengths of input sequences (B)
:param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax)
:return: attention weights with the following shape,
1) multi-head case => attention weights (B, H, Lmax, Tmax),
2) other case => attention weights (B, Lmax, Tmax).
:rtype: float ndarray
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_asr.py
def calculate_all_attentions(self, xs_pad, ilens, ys_pad):
"""E2E attention calculation.
:param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
:param torch.Tensor ilens: batch of lengths of input sequences (B)
:param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax)
:return: attention weights with the following shape,
1) multi-head case => attention weights (B, H, Lmax, Tmax),
2) other case => attention weights (B, Lmax, Tmax).
:rtype: float ndarray
"""
with torch.no_grad():
# 0. Frontend
if self.frontend is not None:
hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens)
hs_pad, hlens = self.feature_transform(hs_pad, hlens)
else:
hs_pad, hlens = xs_pad, ilens
# 1. Encoder
hpad, hlens, _ = self.enc(hs_pad, hlens)
# 2. Decoder
att_ws = self.dec.calculate_all_attentions(hpad, hlens, ys_pad)
return att_ws
decoder_add_arguments(parser)
staticmethod
¶
Add arguments for the decoder.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_asr.py
@staticmethod
def decoder_add_arguments(parser):
"""Add arguments for the decoder."""
group = parser.add_argument_group("E2E decoder setting")
group.add_argument('--dtype', default='lstm', type=str,
choices=['lstm', 'gru'],
help='Type of decoder network architecture')
group.add_argument('--dlayers', default=1, type=int,
help='Number of decoder layers')
group.add_argument('--dunits', default=320, type=int,
help='Number of decoder hidden units')
group.add_argument('--dropout-rate-decoder', default=0.0, type=float,
help='Dropout rate for the decoder')
group.add_argument('--sampling-probability', default=0.0, type=float,
help='Ratio of predicted labels fed back to decoder')
group.add_argument('--lsm-type', const='', default='', type=str, nargs='?',
choices=['', 'unigram'],
help='Apply label smoothing with a specified distribution type')
return parser
encode(self, x)
¶
Encode acoustic features.
:param ndarray x: input acoustic feature (T, D)
:return: encoder outputs
:rtype: torch.Tensor
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_asr.py
def encode(self, x):
"""Encode acoustic features.
:param ndarray x: input acoustic feature (T, D)
:return: encoder outputs
:rtype: torch.Tensor
"""
self.eval()
ilens = [x.shape[0]]
# subsample frame
x = x[::self.subsample[0], :]
p = next(self.parameters())
h = torch.as_tensor(x, device=p.device, dtype=p.dtype)
# make a utt list (1) to use the same interface for encoder
hs = h.contiguous().unsqueeze(0)
# 0. Frontend
if self.frontend is not None:
enhanced, hlens, mask = self.frontend(hs, ilens)
hs, hlens = self.feature_transform(enhanced, hlens)
else:
hs, hlens = hs, ilens
# 1. encoder
hs, _, _ = self.enc(hs, hlens)
return hs.squeeze(0)
encoder_add_arguments(parser)
staticmethod
¶
Add arguments for the encoder.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_asr.py
@staticmethod
def encoder_add_arguments(parser):
"""Add arguments for the encoder."""
group = parser.add_argument_group("E2E encoder setting")
# encoder
group.add_argument('--etype', default='blstmp', type=str,
choices=['lstm', 'blstm', 'lstmp', 'blstmp', 'vgglstmp', 'vggblstmp', 'vgglstm', 'vggblstm',
'gru', 'bgru', 'grup', 'bgrup', 'vgggrup', 'vggbgrup', 'vgggru', 'vggbgru'],
help='Type of encoder network architecture')
group.add_argument('--elayers', default=4, type=int,
help='Number of encoder layers (for shared recognition part in multi-speaker asr mode)')
group.add_argument('--eunits', '-u', default=300, type=int,
help='Number of encoder hidden units')
group.add_argument('--eprojs', default=320, type=int,
help='Number of encoder projection units')
group.add_argument('--subsample', default="1", type=str,
help='Subsample input frames x_y_z means subsample every x frame at 1st layer, '
'every y frame at 2nd layer etc.')
return parser
enhance(self, xs)
¶
Forward only in the frontend stage.
:param ndarray xs: input acoustic feature (T, C, F)
:return: enhanced feature
:rtype: torch.Tensor
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_asr.py
def enhance(self, xs):
"""Forward only in the frontend stage.
:param ndarray xs: input acoustic feature (T, C, F)
:return: enhanced feature
:rtype: torch.Tensor
"""
if self.frontend is None:
raise RuntimeError('Frontend doesn\'t exist')
prev = self.training
self.eval()
ilens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64)
# subsample frame
xs = [xx[::self.subsample[0], :] for xx in xs]
xs = [to_device(self, to_torch_tensor(xx).float()) for xx in xs]
xs_pad = pad_list(xs, 0.0)
enhanced, hlensm, mask = self.frontend(xs_pad, ilens)
if prev:
self.train()
return enhanced.cpu().numpy(), mask.cpu().numpy(), ilens
forward(self, xs_pad, ilens, ys_pad)
¶
E2E forward.
:param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
:param torch.Tensor ilens: batch of lengths of input sequences (B)
:param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax)
:return: loss value
:rtype: torch.Tensor
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_asr.py
def forward(self, xs_pad, ilens, ys_pad):
"""E2E forward.
:param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
:param torch.Tensor ilens: batch of lengths of input sequences (B)
:param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax)
:return: loss value
:rtype: torch.Tensor
"""
# 0. Frontend
if self.frontend is not None:
hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens)
hs_pad, hlens = self.feature_transform(hs_pad, hlens)
else:
hs_pad, hlens = xs_pad, ilens
# 1. Encoder
hs_pad, hlens, _ = self.enc(hs_pad, hlens)
# 2. CTC loss
if self.mtlalpha == 0:
self.loss_ctc = None
else:
self.loss_ctc = self.ctc(hs_pad, hlens, ys_pad)
# 3. attention loss
if self.mtlalpha == 1:
self.loss_att, acc = None, None
else:
self.loss_att, acc, _ = self.dec(hs_pad, hlens, ys_pad)
self.acc = acc
# 4. compute cer without beam search
if self.mtlalpha == 0 or self.char_list is None:
cer_ctc = None
else:
cers = []
y_hats = self.ctc.argmax(hs_pad).data
for i, y in enumerate(y_hats):
y_hat = [x[0] for x in groupby(y)]
y_true = ys_pad[i]
seq_hat = [self.char_list[int(idx)] for idx in y_hat if int(idx) != -1]
seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1]
seq_hat_text = "".join(seq_hat).replace(self.space, ' ')
seq_hat_text = seq_hat_text.replace(self.blank, '')
seq_true_text = "".join(seq_true).replace(self.space, ' ')
hyp_chars = seq_hat_text.replace(' ', '')
ref_chars = seq_true_text.replace(' ', '')
cer_ctc = sum(cers) / len(cers) if cers else None
# 5. compute cer/wer
if self.training or not (self.report_cer or self.report_wer):
cer, wer = 0.0, 0.0
# oracle_cer, oracle_wer = 0.0, 0.0
else:
if self.recog_args.ctc_weight > 0.0:
lpz = self.ctc.log_softmax(hs_pad).data
else:
lpz = None
word_eds, word_ref_lens, char_eds, char_ref_lens = [], [], [], []
nbest_hyps = self.dec.recognize_beam_batch(
hs_pad, torch.tensor(hlens), lpz,
self.recog_args, self.char_list,
self.rnnlm)
# remove <sos> and <eos>
y_hats = [nbest_hyp[0]['yseq'][1:-1] for nbest_hyp in nbest_hyps]
for i, y_hat in enumerate(y_hats):
y_true = ys_pad[i]
seq_hat = [self.char_list[int(idx)] for idx in y_hat if int(idx) != -1]
seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1]
seq_hat_text = "".join(seq_hat).replace(self.recog_args.space, ' ')
seq_hat_text = seq_hat_text.replace(self.recog_args.blank, '')
seq_true_text = "".join(seq_true).replace(self.recog_args.space, ' ')
hyp_words = seq_hat_text.split()
ref_words = seq_true_text.split()
word_ref_lens.append(len(ref_words))
hyp_chars = seq_hat_text.replace(' ', '')
ref_chars = seq_true_text.replace(' ', '')
char_ref_lens.append(len(ref_chars))
wer = 0.0 if not self.report_wer else float(sum(word_eds)) / sum(word_ref_lens)
cer = 0.0 if not self.report_cer else float(sum(char_eds)) / sum(char_ref_lens)
alpha = self.mtlalpha
if alpha == 0:
self.loss = self.loss_att
loss_att_data = float(self.loss_att)
loss_ctc_data = None
elif alpha == 1:
self.loss = self.loss_ctc
loss_att_data = None
loss_ctc_data = float(self.loss_ctc)
else:
self.loss = alpha * self.loss_ctc + (1 - alpha) * self.loss_att
loss_att_data = float(self.loss_att)
loss_ctc_data = float(self.loss_ctc)
loss_data = float(self.loss)
if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
self.reporter.report(loss_ctc_data, loss_att_data, acc, cer_ctc, cer, wer, loss_data)
else:
logging.warning('loss (=%f) is not correct', loss_data)
return self.loss
init_like_chainer(self)
¶
Initialize weight like chainer.
chainer basically uses LeCun way: W ~ Normal(0, fan_in ** -0.5), b = 0
pytorch basically uses W, b ~ Uniform(-fan_in ** -0.5, fan_in ** -0.5)
however, there are two exceptions as far as I know.
- EmbedID.W ~ Normal(0, 1)
- LSTM.upward.b[forget_gate_range] = 1 (but not used in NStepLSTM)
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_asr.py
def init_like_chainer(self):
"""Initialize weight like chainer.
chainer basically uses LeCun way: W ~ Normal(0, fan_in ** -0.5), b = 0
pytorch basically uses W, b ~ Uniform(-fan_in**-0.5, fan_in**-0.5)
however, there are two exceptions as far as I know.
- EmbedID.W ~ Normal(0, 1)
- LSTM.upward.b[forget_gate_range] = 1 (but not used in NStepLSTM)
"""
lecun_normal_init_parameters(self)
# exceptions
# embed weight ~ Normal(0, 1)
self.dec.embed.weight.data.normal_(0, 1)
# forget-bias = 1.0
# https://discuss.pytorch.org/t/set-forget-gate-bias-of-lstm/1745
for l in six.moves.range(len(self.dec.decoder)):
set_forget_bias_to_one(self.dec.decoder[l].bias_ih)
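For illustration only (this is not the lecun_normal_init_parameters helper called above), the scheme described in the docstring can be sketched for a plain torch module: weight matrices are drawn from Normal(0, fan_in ** -0.5) and bias vectors are zeroed.

import torch

def lecun_normal_(module: torch.nn.Module) -> None:
    """Set weights to Normal(0, fan_in ** -0.5) and biases to zero (LeCun-style)."""
    for p in module.parameters():
        if p.dim() == 1:
            p.data.zero_()                  # biases and other vectors
        else:
            fan_in = p.size(1)              # input dimension of the weight matrix
            p.data.normal_(0, fan_in ** -0.5)

layer = torch.nn.Linear(320, 300)
lecun_normal_(layer)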
recognize(self, x, recog_args, char_list, rnnlm=None)
¶
E2E beam search.
:param ndarray x: input acoustic feature (T, D)
:param Namespace recog_args: argument Namespace containing options
:param list char_list: list of characters
:param torch.nn.Module rnnlm: language model module
:return: N-best decoding results
:rtype: list
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_asr.py
def recognize(self, x, recog_args, char_list, rnnlm=None):
"""E2E beam search.
:param ndarray x: input acoustic feature (T, D)
:param Namespace recog_args: argument Namespace containing options
:param list char_list: list of characters
:param torch.nn.Module rnnlm: language model module
:return: N-best decoding results
:rtype: list
"""
hs = self.encode(x).unsqueeze(0)
# calculate log P(z_t|X) for CTC scores
if recog_args.ctc_weight > 0.0:
lpz = self.ctc.log_softmax(hs)[0]
else:
lpz = None
# 2. Decoder
# decode the first utterance
y = self.dec.recognize_beam(hs[0], lpz, recog_args, char_list, rnnlm)
return y
recognize_batch(self, xs, recog_args, char_list, rnnlm=None)
¶
E2E beam search.
:param list xs: list of input acoustic feature arrays [(T_1, D), (T_2, D), ...]
:param Namespace recog_args: argument Namespace containing options
:param list char_list: list of characters
:param torch.nn.Module rnnlm: language model module
:return: N-best decoding results
:rtype: list
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_asr.py
def recognize_batch(self, xs, recog_args, char_list, rnnlm=None):
"""E2E beam search.
:param list xs: list of input acoustic feature arrays [(T_1, D), (T_2, D), ...]
:param Namespace recog_args: argument Namespace containing options
:param list char_list: list of characters
:param torch.nn.Module rnnlm: language model module
:return: N-best decoding results
:rtype: list
"""
prev = self.training
self.eval()
ilens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64)
# subsample frame
xs = [xx[::self.subsample[0], :] for xx in xs]
xs = [to_device(self, to_torch_tensor(xx).float()) for xx in xs]
xs_pad = pad_list(xs, 0.0)
# 0. Frontend
if self.frontend is not None:
enhanced, hlens, mask = self.frontend(xs_pad, ilens)
hs_pad, hlens = self.feature_transform(enhanced, hlens)
else:
hs_pad, hlens = xs_pad, ilens
# 1. Encoder
hs_pad, hlens, _ = self.enc(hs_pad, hlens)
# calculate log P(z_t|X) for CTC scores
if recog_args.ctc_weight > 0.0:
lpz = self.ctc.log_softmax(hs_pad)
normalize_score = False
else:
lpz = None
normalize_score = True
# 2. Decoder
hlens = torch.tensor(list(map(int, hlens))) # make sure hlens is tensor
y = self.dec.recognize_beam_batch(hs_pad, hlens, lpz, recog_args, char_list,
rnnlm, normalize_score=normalize_score)
if prev:
self.train()
return y
scorers(self)
¶
Scorers.
subsample_frames(self, x)
¶
Subsample speech frames in the encoder.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_asr.py
def subsample_frames(self, x):
"""Subsample speech frames in the encoder."""
# subsample frame
x = x[::self.subsample[0], :]
ilen = [x.shape[0]]
h = to_device(self, torch.from_numpy(
np.array(x, dtype=np.float32)))
h.contiguous()
return h, ilen
e2e_asr_transformer
¶
Transformer speech recognition model (pytorch).
E2E (ASRInterface, Module)
¶
E2E module.
:param int idim: dimension of inputs
:param int odim: dimension of outputs
:param Namespace args: argument Namespace containing options
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_asr_transformer.py
class E2E(ASRInterface, torch.nn.Module):
"""E2E module.
:param int idim: dimension of inputs
:param int odim: dimension of outputs
:param Namespace args: argument Namespace containing options
"""
@staticmethod
def add_arguments(parser):
"""Add arguments."""
group = parser.add_argument_group("transformer model setting")
group.add_argument("--transformer-init", type=str, default="pytorch",
choices=["pytorch", "xavier_uniform", "xavier_normal",
"kaiming_uniform", "kaiming_normal"],
help='how to initialize transformer parameters')
group.add_argument("--transformer-input-layer", type=str, default="conv2d",
choices=["conv2d", "linear", "embed"],
help='transformer input layer type')
group.add_argument('--transformer-attn-dropout-rate', default=None, type=float,
help='dropout in transformer attention. use --dropout-rate if None is set')
group.add_argument('--transformer-lr', default=10.0, type=float,
help='Initial value of learning rate')
group.add_argument('--transformer-warmup-steps', default=25000, type=int,
help='optimizer warmup steps')
group.add_argument('--transformer-length-normalized-loss', default=True, type=strtobool,
help='normalize loss by length')
group.add_argument('--dropout-rate', default=0.0, type=float,
help='Dropout rate for the encoder')
# Encoder
group.add_argument('--elayers', default=4, type=int,
help='Number of encoder layers (for shared recognition part in multi-speaker asr mode)')
group.add_argument('--eunits', '-u', default=300, type=int,
help='Number of encoder hidden units')
# Attention
group.add_argument('--adim', default=320, type=int,
help='Number of attention transformation dimensions')
group.add_argument('--aheads', default=4, type=int,
help='Number of heads for multi head attention')
# Decoder
group.add_argument('--dlayers', default=1, type=int,
help='Number of decoder layers')
group.add_argument('--dunits', default=320, type=int,
help='Number of decoder hidden units')
return parser
@property
def attention_plot_class(self):
"""Return PlotAttentionReport."""
pass
def __init__(self, idim, odim, args, ignore_id=-1):
"""Construct an E2E object.
:param int idim: dimension of inputs
:param int odim: dimension of outputs
:param Namespace args: argument Namespace containing options
"""
torch.nn.Module.__init__(self)
if args.transformer_attn_dropout_rate is None:
args.transformer_attn_dropout_rate = args.dropout_rate
self.encoder = Encoder(
idim=idim,
attention_dim=args.adim,
attention_heads=args.aheads,
linear_units=args.eunits,
num_blocks=args.elayers,
input_layer=args.transformer_input_layer,
dropout_rate=args.dropout_rate,
positional_dropout_rate=args.dropout_rate,
attention_dropout_rate=args.transformer_attn_dropout_rate
)
self.decoder = Decoder(
odim=odim,
attention_dim=args.adim,
attention_heads=args.aheads,
linear_units=args.dunits,
num_blocks=args.dlayers,
dropout_rate=args.dropout_rate,
positional_dropout_rate=args.dropout_rate,
self_attention_dropout_rate=args.transformer_attn_dropout_rate,
src_attention_dropout_rate=args.transformer_attn_dropout_rate
)
self.sos = odim - 1
self.eos = odim - 1
self.odim = odim
self.ignore_id = ignore_id
self.subsample = [1]
# self.lsm_weight = a
self.criterion = LabelSmoothingLoss(self.odim, self.ignore_id, args.lsm_weight,
args.transformer_length_normalized_loss)
# self.verbose = args.verbose
self.reset_parameters(args)
self.adim = args.adim
self.mtlalpha = args.mtlalpha
if args.mtlalpha > 0.0:
self.ctc = CTC(odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True)
else:
self.ctc = None
if args.report_cer or args.report_wer:
from tools.espnet_minimal import ErrorCalculator
self.error_calculator = ErrorCalculator(args.char_list,
args.sym_space, args.sym_blank,
args.report_cer, args.report_wer)
else:
self.error_calculator = None
self.rnnlm = None
def reset_parameters(self, args):
"""Initialize parameters."""
# initialize parameters
initialize(self, args.transformer_init)
def forward(self, xs_pad, ilens, ys_pad):
"""E2E forward.
:param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
:param torch.Tensor ilens: batch of lengths of source sequences (B)
:param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
:return: ctc loss value
:rtype: torch.Tensor
:return: attention loss value
:rtype: torch.Tensor
:return: accuracy in attention decoder
:rtype: float
"""
# 1. forward encoder
xs_pad = xs_pad[:, :max(ilens)] # for data parallel
src_mask = (~make_pad_mask(ilens.tolist())).to(xs_pad.device).unsqueeze(-2)
hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
self.hs_pad = hs_pad
# 2. forward decoder
ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
ys_mask = target_mask(ys_in_pad, self.ignore_id)
pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
self.pred_pad = pred_pad
# 3. compute attention loss
loss_att = self.criterion(pred_pad, ys_out_pad)
self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
ignore_label=self.ignore_id)
# TODO(karita) show predicted text
# TODO(karita) calculate these stats
cer_ctc = None
if self.mtlalpha == 0.0:
loss_ctc = None
else:
batch_size = xs_pad.size(0)
hs_len = hs_mask.view(batch_size, -1).sum(1)
loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len, ys_pad)
if self.error_calculator is not None:
ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1, self.adim)).data
cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
# 5. compute cer/wer
if self.training or self.error_calculator is None:
cer, wer = None, None
else:
ys_hat = pred_pad.argmax(dim=-1)
cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
# copied from e2e_asr
alpha = self.mtlalpha
if alpha == 0:
self.loss = loss_att
loss_att_data = float(loss_att)
loss_ctc_data = None
elif alpha == 1:
self.loss = loss_ctc
loss_att_data = None
loss_ctc_data = float(loss_ctc)
else:
self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
loss_att_data = float(loss_att)
loss_ctc_data = float(loss_ctc)
loss_data = float(self.loss)
if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
self.reporter.report(loss_ctc_data, loss_att_data, self.acc, cer_ctc, cer, wer, loss_data)
else:
logging.warning('loss (=%f) is not correct', loss_data)
return self.loss
def scorers(self):
"""Scorers."""
return dict(decoder=self.decoder, ctc=CTCPrefixScorer(self.ctc, self.eos))
def encode(self, x):
"""Encode acoustic features.
:param ndarray x: source acoustic feature (T, D)
:return: encoder outputs
:rtype: torch.Tensor
"""
self.eval()
x = torch.as_tensor(x).unsqueeze(0)
enc_output, _ = self.encoder(x, None)
return enc_output.squeeze(0)
def recognize(self, x, recog_args, char_list=None, rnnlm=None, use_jit=False):
"""Recognize input speech.
:param ndarray x: input acoustic feature (B, T, D) or (T, D)
:param Namespace recog_args: argument Namespace containing options
:param list char_list: list of characters
:param torch.nn.Module rnnlm: language model module
:return: N-best decoding results
:rtype: list
"""
enc_output = self.encode(x).unsqueeze(0)
if recog_args.ctc_weight > 0.0:
lpz = self.ctc.log_softmax(enc_output)
lpz = lpz.squeeze(0)
else:
lpz = None
h = enc_output.squeeze(0)
logging.info('input lengths: ' + str(h.size(0)))
# search params
beam = recog_args.beam_size
penalty = recog_args.penalty
ctc_weight = recog_args.ctc_weight
# prepare sos
y = self.sos
vy = h.new_zeros(1).long()
if recog_args.maxlenratio == 0:
maxlen = h.shape[0]
else:
# maxlen >= 1
maxlen = max(1, int(recog_args.maxlenratio * h.size(0)))
minlen = int(recog_args.minlenratio * h.size(0))
logging.info('max output length: ' + str(maxlen))
logging.info('min output length: ' + str(minlen))
# initialize hypothesis
if rnnlm:
hyp = {'score': 0.0, 'yseq': [y], 'rnnlm_prev': None}
else:
hyp = {'score': 0.0, 'yseq': [y]}
if lpz is not None:
import numpy
from tools.espnet_minimal.nets.ctc_prefix_score import CTCPrefixScore
ctc_prefix_score = CTCPrefixScore(lpz.detach().numpy(), 0, self.eos, numpy)
hyp['ctc_state_prev'] = ctc_prefix_score.initial_state()
hyp['ctc_score_prev'] = 0.0
if ctc_weight != 1.0:
# pre-pruning based on attention scores
from tools.espnet_minimal.nets.pytorch_backend.rnn.decoders import \
CTC_SCORING_RATIO
ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO))
else:
ctc_beam = lpz.shape[-1]
hyps = [hyp]
ended_hyps = []
import six
traced_decoder = None
for i in six.moves.range(maxlen):
logging.debug('position ' + str(i))
hyps_best_kept = []
for hyp in hyps:
vy.unsqueeze(1)
vy[0] = hyp['yseq'][i]
# get nbest local scores and their ids
ys_mask = subsequent_mask(i + 1).unsqueeze(0)
ys = torch.tensor(hyp['yseq']).unsqueeze(0)
# FIXME: jit does not match non-jit result
if use_jit:
if traced_decoder is None:
traced_decoder = torch.jit.trace(self.decoder.forward_one_step,
(ys, ys_mask, enc_output))
local_att_scores = traced_decoder(ys, ys_mask, enc_output)[0]
else:
local_att_scores = self.decoder.forward_one_step(ys, ys_mask, enc_output)[0]
if rnnlm:
rnnlm_state, local_lm_scores = rnnlm.predict(hyp['rnnlm_prev'], vy)
local_scores = local_att_scores + recog_args.lm_weight * local_lm_scores
else:
local_scores = local_att_scores
if lpz is not None:
local_best_scores, local_best_ids = torch.topk(
local_att_scores, ctc_beam, dim=1)
ctc_scores, ctc_states = ctc_prefix_score(
hyp['yseq'], local_best_ids[0], hyp['ctc_state_prev'])
local_scores = \
(1.0 - ctc_weight) * local_att_scores[:, local_best_ids[0]] \
+ ctc_weight * torch.from_numpy(ctc_scores - hyp['ctc_score_prev'])
if rnnlm:
local_scores += recog_args.lm_weight * local_lm_scores[:, local_best_ids[0]]
local_best_scores, joint_best_ids = torch.topk(local_scores, beam, dim=1)
local_best_ids = local_best_ids[:, joint_best_ids[0]]
else:
local_best_scores, local_best_ids = torch.topk(local_scores, beam, dim=1)
for j in six.moves.range(beam):
new_hyp = {}
new_hyp['score'] = hyp['score'] + float(local_best_scores[0, j])
new_hyp['yseq'] = [0] * (1 + len(hyp['yseq']))
new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq']
new_hyp['yseq'][len(hyp['yseq'])] = int(local_best_ids[0, j])
if rnnlm:
new_hyp['rnnlm_prev'] = rnnlm_state
if lpz is not None:
new_hyp['ctc_state_prev'] = ctc_states[joint_best_ids[0, j]]
new_hyp['ctc_score_prev'] = ctc_scores[joint_best_ids[0, j]]
# will be (2 x beam) hyps at most
hyps_best_kept.append(new_hyp)
hyps_best_kept = sorted(
hyps_best_kept, key=lambda x: x['score'], reverse=True)[:beam]
# sort and get nbest
hyps = hyps_best_kept
logging.debug('number of pruned hypotheses: ' + str(len(hyps)))
if char_list is not None:
logging.debug(
'best hypo: ' + ''.join([char_list[int(x)] for x in hyps[0]['yseq'][1:]]))
# add eos in the final loop to avoid that there are no ended hyps
if i == maxlen - 1:
logging.info('adding <eos> in the last position in the loop')
for hyp in hyps:
hyp['yseq'].append(self.eos)
# add ended hypotheses to a final list, and remove them from the current hypotheses
# (this will be a problem if the number of hyps < beam)
remained_hyps = []
for hyp in hyps:
if hyp['yseq'][-1] == self.eos:
# only store the sequence that has more than minlen outputs
# also add penalty
if len(hyp['yseq']) > minlen:
hyp['score'] += (i + 1) * penalty
if rnnlm: # Word LM needs to add final <eos> score
hyp['score'] += recog_args.lm_weight * rnnlm.final(
hyp['rnnlm_prev'])
ended_hyps.append(hyp)
else:
remained_hyps.append(hyp)
# end detection
from tools.espnet_minimal.nets.e2e_asr_common import end_detect
if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
logging.info('end detected at %d', i)
break
hyps = remained_hyps
if len(hyps) > 0:
logging.debug('remaining hypotheses: ' + str(len(hyps)))
else:
logging.info('no hypothesis. Finish decoding.')
break
if char_list is not None:
for hyp in hyps:
logging.debug(
'hypo: ' + ''.join([char_list[int(x)] for x in hyp['yseq'][1:]]))
logging.debug('number of ended hypotheses: ' + str(len(ended_hyps)))
nbest_hyps = sorted(
ended_hyps, key=lambda x: x['score'], reverse=True)[:min(len(ended_hyps), recog_args.nbest)]
# check number of hypotheses
if len(nbest_hyps) == 0:
logging.warning('there are no N-best results, perform recognition again with smaller minlenratio.')
# should copy because Namespace will be overwritten globally
recog_args = Namespace(**vars(recog_args))
recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1)
return self.recognize(x, recog_args, char_list, rnnlm)
logging.info('total log probability: ' + str(nbest_hyps[0]['score']))
logging.info('normalized log probability: ' + str(nbest_hyps[0]['score'] / len(nbest_hyps[0]['yseq'])))
return nbest_hyps
def calculate_all_attentions(self, xs_pad, ilens, ys_pad):
"""E2E attention calculation.
:param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
:param torch.Tensor ilens: batch of lengths of input sequences (B)
:param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax)
:return: attention weights with the following shape,
1) multi-head case => attention weights (B, H, Lmax, Tmax),
2) other case => attention weights (B, Lmax, Tmax).
:rtype: float ndarray
"""
with torch.no_grad():
self.forward(xs_pad, ilens, ys_pad)
ret = dict()
for name, m in self.named_modules():
if isinstance(m, MultiHeadedAttention):
ret[name] = m.attn.cpu().numpy()
return ret
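As with the RNN model, the transformer E2E exposes its hyper-parameters through a static add_arguments helper. The sketch below is illustrative only (not part of the ADVISER sources, and again assuming adviser/ is on PYTHONPATH); the constructor additionally reads fields such as lsm_weight, mtlalpha, ctc_type, report_cer, report_wer, char_list, sym_space and sym_blank, which come from the surrounding training setup rather than from this parser group.

import argparse

from tools.espnet_minimal.nets.pytorch_backend.e2e_asr_transformer import E2E

parser = argparse.ArgumentParser()
E2E.add_arguments(parser)
args = parser.parse_args([])
print(args.transformer_init, args.adim, args.elayers)   # pytorch 320 4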
attention_plot_class
property
readonly
¶
Return PlotAttentionReport.
__init__(self, idim, odim, args, ignore_id=-1)
special
¶
Construct an E2E object.
:param int idim: dimension of inputs
:param int odim: dimension of outputs
:param Namespace args: argument Namespace containing options
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_asr_transformer.py
def __init__(self, idim, odim, args, ignore_id=-1):
"""Construct an E2E object.
:param int idim: dimension of inputs
:param int odim: dimension of outputs
:param Namespace args: argument Namespace containing options
"""
torch.nn.Module.__init__(self)
if args.transformer_attn_dropout_rate is None:
args.transformer_attn_dropout_rate = args.dropout_rate
self.encoder = Encoder(
idim=idim,
attention_dim=args.adim,
attention_heads=args.aheads,
linear_units=args.eunits,
num_blocks=args.elayers,
input_layer=args.transformer_input_layer,
dropout_rate=args.dropout_rate,
positional_dropout_rate=args.dropout_rate,
attention_dropout_rate=args.transformer_attn_dropout_rate
)
self.decoder = Decoder(
odim=odim,
attention_dim=args.adim,
attention_heads=args.aheads,
linear_units=args.dunits,
num_blocks=args.dlayers,
dropout_rate=args.dropout_rate,
positional_dropout_rate=args.dropout_rate,
self_attention_dropout_rate=args.transformer_attn_dropout_rate,
src_attention_dropout_rate=args.transformer_attn_dropout_rate
)
self.sos = odim - 1
self.eos = odim - 1
self.odim = odim
self.ignore_id = ignore_id
self.subsample = [1]
# self.lsm_weight = a
self.criterion = LabelSmoothingLoss(self.odim, self.ignore_id, args.lsm_weight,
args.transformer_length_normalized_loss)
# self.verbose = args.verbose
self.reset_parameters(args)
self.adim = args.adim
self.mtlalpha = args.mtlalpha
if args.mtlalpha > 0.0:
self.ctc = CTC(odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True)
else:
self.ctc = None
if args.report_cer or args.report_wer:
from tools.espnet_minimal import ErrorCalculator
self.error_calculator = ErrorCalculator(args.char_list,
args.sym_space, args.sym_blank,
args.report_cer, args.report_wer)
else:
self.error_calculator = None
self.rnnlm = None
add_arguments(parser)
staticmethod
¶
Add arguments.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_asr_transformer.py
@staticmethod
def add_arguments(parser):
"""Add arguments."""
group = parser.add_argument_group("transformer model setting")
group.add_argument("--transformer-init", type=str, default="pytorch",
choices=["pytorch", "xavier_uniform", "xavier_normal",
"kaiming_uniform", "kaiming_normal"],
help='how to initialize transformer parameters')
group.add_argument("--transformer-input-layer", type=str, default="conv2d",
choices=["conv2d", "linear", "embed"],
help='transformer input layer type')
group.add_argument('--transformer-attn-dropout-rate', default=None, type=float,
help='dropout in transformer attention. use --dropout-rate if None is set')
group.add_argument('--transformer-lr', default=10.0, type=float,
help='Initial value of learning rate')
group.add_argument('--transformer-warmup-steps', default=25000, type=int,
help='optimizer warmup steps')
group.add_argument('--transformer-length-normalized-loss', default=True, type=strtobool,
help='normalize loss by length')
group.add_argument('--dropout-rate', default=0.0, type=float,
help='Dropout rate for the encoder')
# Encoder
group.add_argument('--elayers', default=4, type=int,
help='Number of encoder layers (for shared recognition part in multi-speaker asr mode)')
group.add_argument('--eunits', '-u', default=300, type=int,
help='Number of encoder hidden units')
# Attention
group.add_argument('--adim', default=320, type=int,
help='Number of attention transformation dimensions')
group.add_argument('--aheads', default=4, type=int,
help='Number of heads for multi head attention')
# Decoder
group.add_argument('--dlayers', default=1, type=int,
help='Number of decoder layers')
group.add_argument('--dunits', default=320, type=int,
help='Number of decoder hidden units')
return parser
calculate_all_attentions(self, xs_pad, ilens, ys_pad)
¶
E2E attention calculation.
:param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
:param torch.Tensor ilens: batch of lengths of input sequences (B)
:param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax)
:return: attention weights with the following shape,
1) multi-head case => attention weights (B, H, Lmax, Tmax),
2) other case => attention weights (B, Lmax, Tmax).
:rtype: float ndarray
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_asr_transformer.py
def calculate_all_attentions(self, xs_pad, ilens, ys_pad):
"""E2E attention calculation.
:param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
:param torch.Tensor ilens: batch of lengths of input sequences (B)
:param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax)
:return: attention weights with the following shape,
1) multi-head case => attention weights (B, H, Lmax, Tmax),
2) other case => attention weights (B, Lmax, Tmax).
:rtype: float ndarray
"""
with torch.no_grad():
self.forward(xs_pad, ilens, ys_pad)
ret = dict()
for name, m in self.named_modules():
if isinstance(m, MultiHeadedAttention):
ret[name] = m.attn.cpu().numpy()
return ret
encode(self, x)
¶
Encode acoustic features.
:param ndarray x: source acoustic feature (T, D)
:return: encoder outputs
:rtype: torch.Tensor
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_asr_transformer.py
def encode(self, x):
"""Encode acoustic features.
:param ndarray x: source acoustic feature (T, D)
:return: encoder outputs
:rtype: torch.Tensor
"""
self.eval()
x = torch.as_tensor(x).unsqueeze(0)
enc_output, _ = self.encoder(x, None)
return enc_output.squeeze(0)
forward(self, xs_pad, ilens, ys_pad)
¶
E2E forward.
:param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
:param torch.Tensor ilens: batch of lengths of source sequences (B)
:param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
:return: ctc loss value
:rtype: torch.Tensor
:return: attention loss value
:rtype: torch.Tensor
:return: accuracy in attention decoder
:rtype: float
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_asr_transformer.py
def forward(self, xs_pad, ilens, ys_pad):
"""E2E forward.
:param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
:param torch.Tensor ilens: batch of lengths of source sequences (B)
:param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
:return: ctc loss value
:rtype: torch.Tensor
:return: attention loss value
:rtype: torch.Tensor
:return: accuracy in attention decoder
:rtype: float
"""
# 1. forward encoder
xs_pad = xs_pad[:, :max(ilens)] # for data parallel
src_mask = (~make_pad_mask(ilens.tolist())).to(xs_pad.device).unsqueeze(-2)
hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
self.hs_pad = hs_pad
# 2. forward decoder
ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
ys_mask = target_mask(ys_in_pad, self.ignore_id)
pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
self.pred_pad = pred_pad
# 3. compute attention loss
loss_att = self.criterion(pred_pad, ys_out_pad)
self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad,
ignore_label=self.ignore_id)
# TODO(karita) show predicted text
# TODO(karita) calculate these stats
cer_ctc = None
if self.mtlalpha == 0.0:
loss_ctc = None
else:
batch_size = xs_pad.size(0)
hs_len = hs_mask.view(batch_size, -1).sum(1)
loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len, ys_pad)
if self.error_calculator is not None:
ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1, self.adim)).data
cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
# 5. compute cer/wer
if self.training or self.error_calculator is None:
cer, wer = None, None
else:
ys_hat = pred_pad.argmax(dim=-1)
cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
# copied from e2e_asr
alpha = self.mtlalpha
if alpha == 0:
self.loss = loss_att
loss_att_data = float(loss_att)
loss_ctc_data = None
elif alpha == 1:
self.loss = loss_ctc
loss_att_data = None
loss_ctc_data = float(loss_ctc)
else:
self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
loss_att_data = float(loss_att)
loss_ctc_data = float(loss_ctc)
loss_data = float(self.loss)
if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
self.reporter.report(loss_ctc_data, loss_att_data, self.acc, cer_ctc, cer, wer, loss_data)
else:
logging.warning('loss (=%f) is not correct', loss_data)
return self.loss
recognize(self, x, recog_args, char_list=None, rnnlm=None, use_jit=False)
¶
Recognize input speech.
:param ndarray x: input acoustic feature (B, T, D) or (T, D)
:param Namespace recog_args: argument Namespace containing options
:param list char_list: list of characters
:param torch.nn.Module rnnlm: language model module
:return: N-best decoding results
:rtype: list
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_asr_transformer.py
def recognize(self, x, recog_args, char_list=None, rnnlm=None, use_jit=False):
"""Recognize input speech.
:param ndarray x: input acoustic feature (B, T, D) or (T, D)
:param Namespace recog_args: argument Namespace containing options
:param list char_list: list of characters
:param torch.nn.Module rnnlm: language model module
:return: N-best decoding results
:rtype: list
"""
enc_output = self.encode(x).unsqueeze(0)
if recog_args.ctc_weight > 0.0:
lpz = self.ctc.log_softmax(enc_output)
lpz = lpz.squeeze(0)
else:
lpz = None
h = enc_output.squeeze(0)
logging.info('input lengths: ' + str(h.size(0)))
# search params
beam = recog_args.beam_size
penalty = recog_args.penalty
ctc_weight = recog_args.ctc_weight
# prepare sos
y = self.sos
vy = h.new_zeros(1).long()
if recog_args.maxlenratio == 0:
maxlen = h.shape[0]
else:
# maxlen >= 1
maxlen = max(1, int(recog_args.maxlenratio * h.size(0)))
minlen = int(recog_args.minlenratio * h.size(0))
logging.info('max output length: ' + str(maxlen))
logging.info('min output length: ' + str(minlen))
# initialize hypothesis
if rnnlm:
hyp = {'score': 0.0, 'yseq': [y], 'rnnlm_prev': None}
else:
hyp = {'score': 0.0, 'yseq': [y]}
if lpz is not None:
import numpy
from tools.espnet_minimal.nets.ctc_prefix_score import CTCPrefixScore
ctc_prefix_score = CTCPrefixScore(lpz.detach().numpy(), 0, self.eos, numpy)
hyp['ctc_state_prev'] = ctc_prefix_score.initial_state()
hyp['ctc_score_prev'] = 0.0
if ctc_weight != 1.0:
# pre-pruning based on attention scores
from tools.espnet_minimal.nets.pytorch_backend.rnn.decoders import \
CTC_SCORING_RATIO
ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO))
else:
ctc_beam = lpz.shape[-1]
hyps = [hyp]
ended_hyps = []
import six
traced_decoder = None
for i in six.moves.range(maxlen):
logging.debug('position ' + str(i))
hyps_best_kept = []
for hyp in hyps:
vy.unsqueeze(1)
vy[0] = hyp['yseq'][i]
# get nbest local scores and their ids
ys_mask = subsequent_mask(i + 1).unsqueeze(0)
ys = torch.tensor(hyp['yseq']).unsqueeze(0)
# FIXME: jit does not match non-jit result
if use_jit:
if traced_decoder is None:
traced_decoder = torch.jit.trace(self.decoder.forward_one_step,
(ys, ys_mask, enc_output))
local_att_scores = traced_decoder(ys, ys_mask, enc_output)[0]
else:
local_att_scores = self.decoder.forward_one_step(ys, ys_mask, enc_output)[0]
if rnnlm:
rnnlm_state, local_lm_scores = rnnlm.predict(hyp['rnnlm_prev'], vy)
local_scores = local_att_scores + recog_args.lm_weight * local_lm_scores
else:
local_scores = local_att_scores
if lpz is not None:
local_best_scores, local_best_ids = torch.topk(
local_att_scores, ctc_beam, dim=1)
ctc_scores, ctc_states = ctc_prefix_score(
hyp['yseq'], local_best_ids[0], hyp['ctc_state_prev'])
local_scores = \
(1.0 - ctc_weight) * local_att_scores[:, local_best_ids[0]] \
+ ctc_weight * torch.from_numpy(ctc_scores - hyp['ctc_score_prev'])
if rnnlm:
local_scores += recog_args.lm_weight * local_lm_scores[:, local_best_ids[0]]
local_best_scores, joint_best_ids = torch.topk(local_scores, beam, dim=1)
local_best_ids = local_best_ids[:, joint_best_ids[0]]
else:
local_best_scores, local_best_ids = torch.topk(local_scores, beam, dim=1)
for j in six.moves.range(beam):
new_hyp = {}
new_hyp['score'] = hyp['score'] + float(local_best_scores[0, j])
new_hyp['yseq'] = [0] * (1 + len(hyp['yseq']))
new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq']
new_hyp['yseq'][len(hyp['yseq'])] = int(local_best_ids[0, j])
if rnnlm:
new_hyp['rnnlm_prev'] = rnnlm_state
if lpz is not None:
new_hyp['ctc_state_prev'] = ctc_states[joint_best_ids[0, j]]
new_hyp['ctc_score_prev'] = ctc_scores[joint_best_ids[0, j]]
# will be (2 x beam) hyps at most
hyps_best_kept.append(new_hyp)
hyps_best_kept = sorted(
hyps_best_kept, key=lambda x: x['score'], reverse=True)[:beam]
# sort and get nbest
hyps = hyps_best_kept
logging.debug('number of pruned hypotheses: ' + str(len(hyps)))
if char_list is not None:
logging.debug(
'best hypo: ' + ''.join([char_list[int(x)] for x in hyps[0]['yseq'][1:]]))
# add eos in the final loop to avoid that there are no ended hyps
if i == maxlen - 1:
logging.info('adding <eos> in the last position in the loop')
for hyp in hyps:
hyp['yseq'].append(self.eos)
# add ended hypotheses to a final list, and remove them from the current hypotheses
# (this can be a problem if the number of hyps < beam)
remained_hyps = []
for hyp in hyps:
if hyp['yseq'][-1] == self.eos:
# only store the sequence that has more than minlen outputs
# also add penalty
if len(hyp['yseq']) > minlen:
hyp['score'] += (i + 1) * penalty
if rnnlm: # Word LM needs to add final <eos> score
hyp['score'] += recog_args.lm_weight * rnnlm.final(
hyp['rnnlm_prev'])
ended_hyps.append(hyp)
else:
remained_hyps.append(hyp)
# end detection
from tools.espnet_minimal.nets.e2e_asr_common import end_detect
if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
logging.info('end detected at %d', i)
break
hyps = remained_hyps
if len(hyps) > 0:
logging.debug('remaining hypotheses: ' + str(len(hyps)))
else:
logging.info('no hypothesis. Finish decoding.')
break
if char_list is not None:
for hyp in hyps:
logging.debug(
'hypo: ' + ''.join([char_list[int(x)] for x in hyp['yseq'][1:]]))
logging.debug('number of ended hypotheses: ' + str(len(ended_hyps)))
nbest_hyps = sorted(
ended_hyps, key=lambda x: x['score'], reverse=True)[:min(len(ended_hyps), recog_args.nbest)]
# check the number of hypotheses
if len(nbest_hyps) == 0:
logging.warning('there are no N-best results; performing recognition again with a smaller minlenratio.')
# should copy because the Namespace will be overwritten globally
recog_args = Namespace(**vars(recog_args))
recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1)
return self.recognize(x, recog_args, char_list, rnnlm)
logging.info('total log probability: ' + str(nbest_hyps[0]['score']))
logging.info('normalized log probability: ' + str(nbest_hyps[0]['score'] / len(nbest_hyps[0]['yseq'])))
return nbest_hyps
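For orientation, here is a small usage sketch of how the returned n-best list might be consumed. It is illustrative only: the model instance, the encoder input features and the character list (named model, feats and char_list below) are assumed to exist, and the Namespace simply mirrors the recog_args attributes read in the source above.
from argparse import Namespace

recog_args = Namespace(beam_size=5, penalty=0.0, ctc_weight=0.0,
                       maxlenratio=0.0, minlenratio=0.0, lm_weight=0.0, nbest=3)
nbest_hyps = model.recognize(feats, recog_args, char_list=char_list, rnnlm=None)
for rank, hyp in enumerate(nbest_hyps, 1):
    # yseq starts with <sos>; strip it and map the remaining ids back to characters
    text = ''.join(char_list[idx] for idx in hyp['yseq'][1:])
    print(f"{rank}. score={hyp['score']:.2f} text={text}")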
reset_parameters(self, args)
¶
scorers(self)
¶
e2e_tts_fastspeech
¶
FastSpeech related modules.
FeedForwardTransformer (TTSInterface, Module)
¶
Feed Forward Transformer for TTS a.k.a. FastSpeech.
This is a module of FastSpeech, a feed-forward Transformer with a duration predictor described in
FastSpeech: Fast, Robust and Controllable Text to Speech (https://arxiv.org/pdf/1905.09263.pdf),
which does not require any auto-regressive processing during inference, resulting in fast decoding
compared with the auto-regressive Transformer.
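The non-autoregressive trick is the length regulator: each encoder state is repeated according to its predicted duration before being fed to the decoder. The following stand-alone sketch (not taken from the library; shapes and values are made up) illustrates the idea with plain PyTorch:
import torch

# One utterance: 4 encoder states of dimension 3, with a predicted duration per state.
hs = torch.arange(12.0).view(4, 3)        # (T_in, adim)
durations = torch.tensor([1, 3, 0, 2])    # frames per input token (a 0 drops the token)

# Length regulation: repeat each state by its duration -> (sum(durations), adim)
expanded = torch.repeat_interleave(hs, durations, dim=0)
print(expanded.shape)                     # torch.Size([6, 3])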
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_tts_fastspeech.py
class FeedForwardTransformer(TTSInterface, torch.nn.Module):
"""Feed Forward Transformer for TTS a.k.a. FastSpeech.
This is a module of FastSpeech, feed-forward Transformer with duration predictor described in
`FastSpeech: Fast, Robust and Controllable Text to Speech`_, which does not require any auto-regressive
processing during inference, resulting in fast decoding compared with auto-regressive Transformer.
.. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
https://arxiv.org/pdf/1905.09263.pdf
"""
@staticmethod
def add_arguments(parser):
"""Add model-specific arguments to the parser."""
group = parser.add_argument_group("feed-forward transformer model setting")
# network structure related
group.add_argument("--adim", default=384, type=int,
help="Number of attention transformation dimensions")
group.add_argument("--aheads", default=4, type=int,
help="Number of heads for multi head attention")
group.add_argument("--elayers", default=6, type=int,
help="Number of encoder layers")
group.add_argument("--eunits", default=1536, type=int,
help="Number of encoder hidden units")
group.add_argument("--dlayers", default=6, type=int,
help="Number of decoder layers")
group.add_argument("--dunits", default=1536, type=int,
help="Number of decoder hidden units")
group.add_argument("--positionwise-layer-type", default="linear", type=str,
choices=["linear", "conv1d", "conv1d-linear"],
help="Positionwise layer type.")
group.add_argument("--positionwise-conv-kernel-size", default=3, type=int,
help="Kernel size of positionwise conv1d layer")
group.add_argument("--postnet-layers", default=0, type=int,
help="Number of postnet layers")
group.add_argument("--postnet-chans", default=256, type=int,
help="Number of postnet channels")
group.add_argument("--postnet-filts", default=5, type=int,
help="Filter size of postnet")
group.add_argument("--use-batch-norm", default=True, type=strtobool,
help="Whether to use batch normalization")
group.add_argument("--use-scaled-pos-enc", default=True, type=strtobool,
help="Use trainable scaled positional encoding instead of the fixed scale one")
group.add_argument("--encoder-normalize-before", default=False, type=strtobool,
help="Whether to apply layer norm before encoder block")
group.add_argument("--decoder-normalize-before", default=False, type=strtobool,
help="Whether to apply layer norm before decoder block")
group.add_argument("--encoder-concat-after", default=False, type=strtobool,
help="Whether to concatenate attention layer's input and output in encoder")
group.add_argument("--decoder-concat-after", default=False, type=strtobool,
help="Whether to concatenate attention layer's input and output in decoder")
group.add_argument("--duration-predictor-layers", default=2, type=int,
help="Number of layers in duration predictor")
group.add_argument("--duration-predictor-chans", default=384, type=int,
help="Number of channels in duration predictor")
group.add_argument("--duration-predictor-kernel-size", default=3, type=int,
help="Kernel size in duration predictor")
group.add_argument("--teacher-model", default=None, type=str, nargs="?",
help="Teacher model file path")
group.add_argument("--reduction-factor", default=1, type=int,
help="Reduction factor")
group.add_argument("--spk-embed-dim", default=None, type=int,
help="Number of speaker embedding dimensions")
group.add_argument("--spk-embed-integration-type", type=str, default="add",
choices=["add", "concat"],
help="How to integrate speaker embedding")
# training related
group.add_argument("--transformer-init", type=str, default="pytorch",
choices=["pytorch", "xavier_uniform", "xavier_normal",
"kaiming_uniform", "kaiming_normal"],
help="How to initialize transformer parameters")
group.add_argument("--initial-encoder-alpha", type=float, default=1.0,
help="Initial alpha value in encoder's ScaledPositionalEncoding")
group.add_argument("--initial-decoder-alpha", type=float, default=1.0,
help="Initial alpha value in decoder's ScaledPositionalEncoding")
group.add_argument("--transformer-lr", default=1.0, type=float,
help="Initial value of learning rate")
group.add_argument("--transformer-warmup-steps", default=4000, type=int,
help="Optimizer warmup steps")
group.add_argument("--transformer-enc-dropout-rate", default=0.1, type=float,
help="Dropout rate for transformer encoder except for attention")
group.add_argument("--transformer-enc-positional-dropout-rate", default=0.1, type=float,
help="Dropout rate for transformer encoder positional encoding")
group.add_argument("--transformer-enc-attn-dropout-rate", default=0.1, type=float,
help="Dropout rate for transformer encoder self-attention")
group.add_argument("--transformer-dec-dropout-rate", default=0.1, type=float,
help="Dropout rate for transformer decoder except for attention and pos encoding")
group.add_argument("--transformer-dec-positional-dropout-rate", default=0.1, type=float,
help="Dropout rate for transformer decoder positional encoding")
group.add_argument("--transformer-dec-attn-dropout-rate", default=0.1, type=float,
help="Dropout rate for transformer decoder self-attention")
group.add_argument("--transformer-enc-dec-attn-dropout-rate", default=0.1, type=float,
help="Dropout rate for transformer encoder-decoder attention")
group.add_argument("--duration-predictor-dropout-rate", default=0.1, type=float,
help="Dropout rate for duration predictor")
group.add_argument("--postnet-dropout-rate", default=0.5, type=float,
help="Dropout rate in postnet")
group.add_argument("--transfer-encoder-from-teacher", default=True, type=strtobool,
help="Whether to transfer teacher's parameters")
group.add_argument("--transferred-encoder-module", default="all", type=str,
choices=["all", "embed"],
help="Encoder modules to be transferred from teacher")
# loss related
group.add_argument("--use-masking", default=True, type=strtobool,
help="Whether to use masking in calculation of loss")
group.add_argument("--use-weighted-masking", default=False, type=strtobool,
help="Whether to use weighted masking in calculation of loss")
return parser
def __init__(self, idim, odim, args=None):
"""Initialize feed-forward Transformer module.
Args:
idim (int): Dimension of the inputs.
odim (int): Dimension of the outputs.
args (Namespace, optional):
- elayers (int): Number of encoder layers.
- eunits (int): Number of encoder hidden units.
- adim (int): Number of attention transformation dimensions.
- aheads (int): Number of heads for multi head attention.
- dlayers (int): Number of decoder layers.
- dunits (int): Number of decoder hidden units.
- use_scaled_pos_enc (bool): Whether to use trainable scaled positional encoding.
- encoder_normalize_before (bool): Whether to perform layer normalization before encoder block.
- decoder_normalize_before (bool): Whether to perform layer normalization before decoder block.
- encoder_concat_after (bool): Whether to concatenate attention layer's input and output in encoder.
- decoder_concat_after (bool): Whether to concatenate attention layer's input and output in decoder.
- duration_predictor_layers (int): Number of duration predictor layers.
- duration_predictor_chans (int): Number of duration predictor channels.
- duration_predictor_kernel_size (int): Kernel size of duration predictor.
- spk_embed_dim (int): Number of speaker embedding dimensions.
- spk_embed_integration_type: How to integrate speaker embedding.
- teacher_model (str): Teacher auto-regressive transformer model path.
- reduction_factor (int): Reduction factor.
- transformer_init (float): How to initialize transformer parameters.
- transformer_lr (float): Initial value of learning rate.
- transformer_warmup_steps (int): Optimizer warmup steps.
- transformer_enc_dropout_rate (float): Dropout rate in encoder except attention & positional encoding.
- transformer_enc_positional_dropout_rate (float): Dropout rate after encoder positional encoding.
- transformer_enc_attn_dropout_rate (float): Dropout rate in encoder self-attention module.
- transformer_dec_dropout_rate (float): Dropout rate in decoder except attention & positional encoding.
- transformer_dec_positional_dropout_rate (float): Dropout rate after decoder positional encoding.
- transformer_dec_attn_dropout_rate (float): Dropout rate in decoder self-attention module.
- transformer_enc_dec_attn_dropout_rate (float): Dropout rate in encoder-decoder attention module.
- use_masking (bool): Whether to apply masking for padded part in loss calculation.
- use_weighted_masking (bool): Whether to apply weighted masking in loss calculation.
- transfer_encoder_from_teacher: Whether to transfer encoder using teacher encoder parameters.
- transferred_encoder_module: Encoder module to be initialized using teacher parameters.
"""
# initialize base classes
TTSInterface.__init__(self)
torch.nn.Module.__init__(self)
# fill missing arguments
args = fill_missing_args(args, self.add_arguments)
# store hyperparameters
self.idim = idim
self.odim = odim
self.reduction_factor = args.reduction_factor
self.use_scaled_pos_enc = args.use_scaled_pos_enc
self.spk_embed_dim = args.spk_embed_dim
if self.spk_embed_dim is not None:
self.spk_embed_integration_type = args.spk_embed_integration_type
# use idx 0 as padding idx
padding_idx = 0
# get positional encoding class
pos_enc_class = ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding
# define encoder
encoder_input_layer = torch.nn.Embedding(
num_embeddings=idim,
embedding_dim=args.adim,
padding_idx=padding_idx
)
self.encoder = Encoder(
idim=idim,
attention_dim=args.adim,
attention_heads=args.aheads,
linear_units=args.eunits,
num_blocks=args.elayers,
input_layer=encoder_input_layer,
dropout_rate=args.transformer_enc_dropout_rate,
positional_dropout_rate=args.transformer_enc_positional_dropout_rate,
attention_dropout_rate=args.transformer_enc_attn_dropout_rate,
pos_enc_class=pos_enc_class,
normalize_before=args.encoder_normalize_before,
concat_after=args.encoder_concat_after,
positionwise_layer_type=args.positionwise_layer_type,
positionwise_conv_kernel_size=args.positionwise_conv_kernel_size
)
# define additional projection for speaker embedding
if self.spk_embed_dim is not None:
if self.spk_embed_integration_type == "add":
self.projection = torch.nn.Linear(self.spk_embed_dim, args.adim)
else:
self.projection = torch.nn.Linear(args.adim + self.spk_embed_dim, args.adim)
# define duration predictor
self.duration_predictor = DurationPredictor(
idim=args.adim,
n_layers=args.duration_predictor_layers,
n_chans=args.duration_predictor_chans,
kernel_size=args.duration_predictor_kernel_size,
dropout_rate=args.duration_predictor_dropout_rate,
)
# define length regulator
self.length_regulator = LengthRegulator()
# define decoder
# NOTE: we use encoder as decoder because fastspeech's decoder is the same as encoder
self.decoder = Encoder(
idim=0,
attention_dim=args.adim,
attention_heads=args.aheads,
linear_units=args.dunits,
num_blocks=args.dlayers,
input_layer=None,
dropout_rate=args.transformer_dec_dropout_rate,
positional_dropout_rate=args.transformer_dec_positional_dropout_rate,
attention_dropout_rate=args.transformer_dec_attn_dropout_rate,
pos_enc_class=pos_enc_class,
normalize_before=args.decoder_normalize_before,
concat_after=args.decoder_concat_after,
positionwise_layer_type=args.positionwise_layer_type,
positionwise_conv_kernel_size=args.positionwise_conv_kernel_size
)
# define final projection
self.feat_out = torch.nn.Linear(args.adim, odim * args.reduction_factor)
# define postnet
self.postnet = None if args.postnet_layers == 0 else Postnet(
idim=idim,
odim=odim,
n_layers=args.postnet_layers,
n_chans=args.postnet_chans,
n_filts=args.postnet_filts,
use_batch_norm=args.use_batch_norm,
dropout_rate=args.postnet_dropout_rate
)
# initialize parameters
self._reset_parameters(init_type=args.transformer_init,
init_enc_alpha=args.initial_encoder_alpha,
init_dec_alpha=args.initial_decoder_alpha)
# define teacher model
if args.teacher_model is not None:
self.teacher = self._load_teacher_model(args.teacher_model)
else:
self.teacher = None
# define duration calculator
if self.teacher is not None:
self.duration_calculator = DurationCalculator(self.teacher)
else:
self.duration_calculator = None
# transfer teacher parameters
if self.teacher is not None and args.transfer_encoder_from_teacher:
self._transfer_from_teacher(args.transferred_encoder_module)
# define criterions
self.criterion = FeedForwardTransformerLoss(
use_masking=args.use_masking,
use_weighted_masking=args.use_weighted_masking
)
def _forward(self, xs, ilens, ys=None, olens=None, spembs=None, ds=None, is_inference=False):
# forward encoder
x_masks = self._source_mask(ilens)
hs, _ = self.encoder(xs, x_masks) # (B, Tmax, adim)
# integrate speaker embedding
if self.spk_embed_dim is not None:
hs = self._integrate_with_spk_embed(hs, spembs)
# forward duration predictor and length regulator
d_masks = make_pad_mask(ilens).to(xs.device)
if is_inference:
d_outs = self.duration_predictor.inference(hs, d_masks) # (B, Tmax)
hs = self.length_regulator(hs, d_outs, ilens) # (B, Lmax, adim)
else:
if ds is None:
with torch.no_grad():
ds = self.duration_calculator(xs, ilens, ys, olens, spembs) # (B, Tmax)
d_outs = self.duration_predictor(hs, d_masks) # (B, Tmax)
hs = self.length_regulator(hs, ds, ilens) # (B, Lmax, adim)
# forward decoder
if olens is not None:
if self.reduction_factor > 1:
olens_in = olens.new([olen // self.reduction_factor for olen in olens])
else:
olens_in = olens
h_masks = self._source_mask(olens_in)
else:
h_masks = None
zs, _ = self.decoder(hs, h_masks) # (B, Lmax, adim)
before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim) # (B, Lmax, odim)
# postnet -> (B, Lmax//r * r, odim)
if self.postnet is None:
after_outs = before_outs
else:
after_outs = before_outs + self.postnet(before_outs.transpose(1, 2)).transpose(1, 2)
if is_inference:
return before_outs, after_outs, d_outs
else:
return before_outs, after_outs, ds, d_outs
def forward(self, xs, ilens, ys, olens, spembs=None, extras=None, *args, **kwargs):
"""Calculate forward propagation.
Args:
xs (Tensor): Batch of padded character ids (B, Tmax).
ilens (LongTensor): Batch of lengths of each input batch (B,).
ys (Tensor): Batch of padded target features (B, Lmax, odim).
olens (LongTensor): Batch of the lengths of each target (B,).
spembs (Tensor, optional): Batch of speaker embedding vectors (B, spk_embed_dim).
extras (Tensor, optional): Batch of precalculated durations (B, Tmax, 1).
Returns:
Tensor: Loss value.
"""
# remove unnecessary padded part (for multi-gpus)
xs = xs[:, :max(ilens)]
ys = ys[:, :max(olens)]
if extras is not None:
extras = extras[:, :max(ilens)].squeeze(-1)
# forward propagation
before_outs, after_outs, ds, d_outs = self._forward(
xs, ilens, ys, olens, spembs=spembs, ds=extras, is_inference=False)
# modify the mod part of the ground truth
if self.reduction_factor > 1:
olens = olens.new([olen - olen % self.reduction_factor for olen in olens])
max_olen = max(olens)
ys = ys[:, :max_olen]
# calculate loss
if self.postnet is None:
l1_loss, duration_loss = self.criterion(None, before_outs, d_outs, ys, ds, ilens, olens)
else:
l1_loss, duration_loss = self.criterion(after_outs, before_outs, d_outs, ys, ds, ilens, olens)
loss = l1_loss + duration_loss
report_keys = [
{"l1_loss": l1_loss.item()},
{"duration_loss": duration_loss.item()},
{"loss": loss.item()},
]
# report extra information
if self.use_scaled_pos_enc:
report_keys += [
{"encoder_alpha": self.encoder.embed[-1].alpha.data.item()},
{"decoder_alpha": self.decoder.embed[-1].alpha.data.item()},
]
self.reporter.report(report_keys)
return loss
def calculate_all_attentions(self, xs, ilens, ys, olens, spembs=None, extras=None, *args, **kwargs):
"""Calculate all of the attention weights.
Args:
xs (Tensor): Batch of padded character ids (B, Tmax).
ilens (LongTensor): Batch of lengths of each input batch (B,).
ys (Tensor): Batch of padded target features (B, Lmax, odim).
olens (LongTensor): Batch of the lengths of each target (B,).
spembs (Tensor, optional): Batch of speaker embedding vectors (B, spk_embed_dim).
extras (Tensor, optional): Batch of precalculated durations (B, Tmax, 1).
Returns:
dict: Dict of attention weights and outputs.
"""
with torch.no_grad():
# remove unnecessary padded part (for multi-gpus)
xs = xs[:, :max(ilens)]
ys = ys[:, :max(olens)]
if extras is not None:
extras = extras[:, :max(ilens)].squeeze(-1)
# forward propagation
outs = self._forward(xs, ilens, ys, olens, spembs=spembs, ds=extras, is_inference=False)[1]
att_ws_dict = dict()
for name, m in self.named_modules():
if isinstance(m, MultiHeadedAttention):
attn = m.attn.cpu().numpy()
if "encoder" in name:
attn = [a[:, :l, :l] for a, l in zip(attn, ilens.tolist())]
elif "decoder" in name:
if "src" in name:
attn = [a[:, :ol, :il] for a, il, ol in zip(attn, ilens.tolist(), olens.tolist())]
elif "self" in name:
attn = [a[:, :l, :l] for a, l in zip(attn, olens.tolist())]
else:
logging.warning("unknown attention module: " + name)
else:
logging.warning("unknown attention module: " + name)
att_ws_dict[name] = attn
att_ws_dict["predicted_fbank"] = [m[:l].T for m, l in zip(outs.cpu().numpy(), olens.tolist())]
return att_ws_dict
def inference(self, x, inference_args, spemb=None, *args, **kwargs):
"""Generate the sequence of features given the sequences of characters.
Args:
x (Tensor): Input sequence of characters (T,).
inference_args (Namespace): Dummy for compatibility.
spemb (Tensor, optional): Speaker embedding vector (spk_embed_dim).
Returns:
Tensor: Output sequence of features (L, odim).
None: Dummy for compatibility.
None: Dummy for compatibility.
"""
# setup batch axis
ilens = torch.tensor([x.shape[0]], dtype=torch.long, device=x.device)
xs = x.unsqueeze(0)
if spemb is not None:
spembs = spemb.unsqueeze(0)
else:
spembs = None
# inference
_, outs, _ = self._forward(xs, ilens, spembs=spembs, is_inference=True) # (1, L, odim)
return outs[0], None, None
def _integrate_with_spk_embed(self, hs, spembs):
"""Integrate speaker embedding with hidden states.
Args:
hs (Tensor): Batch of hidden state sequences (B, Tmax, adim).
spembs (Tensor): Batch of speaker embeddings (B, spk_embed_dim).
Returns:
Tensor: Batch of integrated hidden state sequences (B, Tmax, adim)
"""
if self.spk_embed_integration_type == "add":
# apply projection and then add to hidden states
spembs = self.projection(F.normalize(spembs))
hs = hs + spembs.unsqueeze(1)
elif self.spk_embed_integration_type == "concat":
# concat hidden states with spk embeds and then apply projection
spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1)
hs = self.projection(torch.cat([hs, spembs], dim=-1))
else:
raise NotImplementedError("support only add or concat.")
return hs
def _source_mask(self, ilens):
"""Make masks for self-attention.
Args:
ilens (LongTensor or List): Batch of lengths (B,).
Returns:
Tensor: Mask tensor for self-attention.
dtype=torch.uint8 in PyTorch versions before 1.2
dtype=torch.bool in PyTorch 1.2 and later
Examples:
>>> ilens = [5, 3]
>>> self._source_mask(ilens)
tensor([[[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1]],
[[1, 1, 1, 0, 0],
[1, 1, 1, 0, 0],
[1, 1, 1, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]]], dtype=torch.uint8)
"""
x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device)
return x_masks.unsqueeze(-2) & x_masks.unsqueeze(-1)
def _reset_parameters(self, init_type, init_enc_alpha=1.0, init_dec_alpha=1.0):
# initialize parameters
initialize(self, init_type)
# initialize alpha in scaled positional encoding
if self.use_scaled_pos_enc:
self.encoder.embed[-1].alpha.data = torch.tensor(init_enc_alpha)
self.decoder.embed[-1].alpha.data = torch.tensor(init_dec_alpha)
def _transfer_from_teacher(self, transferred_encoder_module):
if transferred_encoder_module == "all":
for (n1, p1), (n2, p2) in zip(self.encoder.named_parameters(),
self.teacher.encoder.named_parameters()):
assert n1 == n2, "It seems that encoder structure is different."
assert p1.shape == p2.shape, "It seems that encoder size is different."
p1.data.copy_(p2.data)
elif transferred_encoder_module == "embed":
student_shape = self.encoder.embed[0].weight.data.shape
teacher_shape = self.teacher.encoder.embed[0].weight.data.shape
assert student_shape == teacher_shape, "It seems that embed dimension is different."
self.encoder.embed[0].weight.data.copy_(
self.teacher.encoder.embed[0].weight.data)
else:
raise NotImplementedError("Support only all or embed.")
@property
def attention_plot_class(self):
"""Return plot class for attention weight plot."""
return TTSPlot
@property
def base_plot_keys(self):
"""Return base key names to plot during training. keys should match what `chainer.reporter` reports.
If you add the key `loss`, the reporter will report `main/loss` and `validation/main/loss` values.
In addition, `loss.png` will be created as a figure visualizing `main/loss` and `validation/main/loss` values.
Returns:
list: List of strings which are base keys to plot during training.
"""
plot_keys = ["loss", "l1_loss", "duration_loss"]
if self.use_scaled_pos_enc:
plot_keys += ["encoder_alpha", "decoder_alpha"]
return plot_keys
attention_plot_class
property
readonly
¶
Return plot class for attention weight plot.
base_plot_keys
property
readonly
¶
Return base key names to plot during training. Keys should match what chainer.reporter reports.
If you add the key loss, the reporter will report main/loss and validation/main/loss values.
In addition, loss.png will be created as a figure visualizing main/loss and validation/main/loss values.
Returns:
Type | Description |
---|---|
list | List of strings which are base keys to plot during training. |
__init__(self, idim, odim, args=None)
special
¶
Initialize feed-forward Transformer module.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
idim | int | Dimension of the inputs. | required |
odim | int | Dimension of the outputs. | required |
args | Namespace | Model hyperparameters (see the option list in the docstring below). | None |
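As a quick orientation, a minimal construction sketch (not part of the documented source; the idim and odim values are arbitrary placeholders) that builds the default hyperparameters via add_arguments and instantiates the model without a teacher:
import argparse

parser = argparse.ArgumentParser()
FeedForwardTransformer.add_arguments(parser)
args = parser.parse_args([])              # all defaults, no teacher model

# 100 input symbols, 80-dimensional output features (placeholder values)
model = FeedForwardTransformer(idim=100, odim=80, args=args)
print(sum(p.numel() for p in model.parameters()))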
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_tts_fastspeech.py
def __init__(self, idim, odim, args=None):
"""Initialize feed-forward Transformer module.
Args:
idim (int): Dimension of the inputs.
odim (int): Dimension of the outputs.
args (Namespace, optional):
- elayers (int): Number of encoder layers.
- eunits (int): Number of encoder hidden units.
- adim (int): Number of attention transformation dimensions.
- aheads (int): Number of heads for multi head attention.
- dlayers (int): Number of decoder layers.
- dunits (int): Number of decoder hidden units.
- use_scaled_pos_enc (bool): Whether to use trainable scaled positional encoding.
- encoder_normalize_before (bool): Whether to perform layer normalization before encoder block.
- decoder_normalize_before (bool): Whether to perform layer normalization before decoder block.
- encoder_concat_after (bool): Whether to concatenate attention layer's input and output in encoder.
- decoder_concat_after (bool): Whether to concatenate attention layer's input and output in decoder.
- duration_predictor_layers (int): Number of duration predictor layers.
- duration_predictor_chans (int): Number of duration predictor channels.
- duration_predictor_kernel_size (int): Kernel size of duration predictor.
- spk_embed_dim (int): Number of speaker embedding dimensions.
- spk_embed_integration_type: How to integrate speaker embedding.
- teacher_model (str): Teacher auto-regressive transformer model path.
- reduction_factor (int): Reduction factor.
- transformer_init (float): How to initialize transformer parameters.
- transformer_lr (float): Initial value of learning rate.
- transformer_warmup_steps (int): Optimizer warmup steps.
- transformer_enc_dropout_rate (float): Dropout rate in encoder except attention & positional encoding.
- transformer_enc_positional_dropout_rate (float): Dropout rate after encoder positional encoding.
- transformer_enc_attn_dropout_rate (float): Dropout rate in encoder self-attention module.
- transformer_dec_dropout_rate (float): Dropout rate in decoder except attention & positional encoding.
- transformer_dec_positional_dropout_rate (float): Dropout rate after decoder positional encoding.
- transformer_dec_attn_dropout_rate (float): Dropout rate in decoder self-attention module.
- transformer_enc_dec_attn_dropout_rate (float): Dropout rate in encoder-decoder attention module.
- use_masking (bool): Whether to apply masking for padded part in loss calculation.
- use_weighted_masking (bool): Whether to apply weighted masking in loss calculation.
- transfer_encoder_from_teacher: Whether to transfer encoder using teacher encoder parameters.
- transferred_encoder_module: Encoder module to be initialized using teacher parameters.
"""
# initialize base classes
TTSInterface.__init__(self)
torch.nn.Module.__init__(self)
# fill missing arguments
args = fill_missing_args(args, self.add_arguments)
# store hyperparameters
self.idim = idim
self.odim = odim
self.reduction_factor = args.reduction_factor
self.use_scaled_pos_enc = args.use_scaled_pos_enc
self.spk_embed_dim = args.spk_embed_dim
if self.spk_embed_dim is not None:
self.spk_embed_integration_type = args.spk_embed_integration_type
# use idx 0 as padding idx
padding_idx = 0
# get positional encoding class
pos_enc_class = ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding
# define encoder
encoder_input_layer = torch.nn.Embedding(
num_embeddings=idim,
embedding_dim=args.adim,
padding_idx=padding_idx
)
self.encoder = Encoder(
idim=idim,
attention_dim=args.adim,
attention_heads=args.aheads,
linear_units=args.eunits,
num_blocks=args.elayers,
input_layer=encoder_input_layer,
dropout_rate=args.transformer_enc_dropout_rate,
positional_dropout_rate=args.transformer_enc_positional_dropout_rate,
attention_dropout_rate=args.transformer_enc_attn_dropout_rate,
pos_enc_class=pos_enc_class,
normalize_before=args.encoder_normalize_before,
concat_after=args.encoder_concat_after,
positionwise_layer_type=args.positionwise_layer_type,
positionwise_conv_kernel_size=args.positionwise_conv_kernel_size
)
# define additional projection for speaker embedding
if self.spk_embed_dim is not None:
if self.spk_embed_integration_type == "add":
self.projection = torch.nn.Linear(self.spk_embed_dim, args.adim)
else:
self.projection = torch.nn.Linear(args.adim + self.spk_embed_dim, args.adim)
# define duration predictor
self.duration_predictor = DurationPredictor(
idim=args.adim,
n_layers=args.duration_predictor_layers,
n_chans=args.duration_predictor_chans,
kernel_size=args.duration_predictor_kernel_size,
dropout_rate=args.duration_predictor_dropout_rate,
)
# define length regulator
self.length_regulator = LengthRegulator()
# define decoder
# NOTE: we use encoder as decoder because fastspeech's decoder is the same as encoder
self.decoder = Encoder(
idim=0,
attention_dim=args.adim,
attention_heads=args.aheads,
linear_units=args.dunits,
num_blocks=args.dlayers,
input_layer=None,
dropout_rate=args.transformer_dec_dropout_rate,
positional_dropout_rate=args.transformer_dec_positional_dropout_rate,
attention_dropout_rate=args.transformer_dec_attn_dropout_rate,
pos_enc_class=pos_enc_class,
normalize_before=args.decoder_normalize_before,
concat_after=args.decoder_concat_after,
positionwise_layer_type=args.positionwise_layer_type,
positionwise_conv_kernel_size=args.positionwise_conv_kernel_size
)
# define final projection
self.feat_out = torch.nn.Linear(args.adim, odim * args.reduction_factor)
# define postnet
self.postnet = None if args.postnet_layers == 0 else Postnet(
idim=idim,
odim=odim,
n_layers=args.postnet_layers,
n_chans=args.postnet_chans,
n_filts=args.postnet_filts,
use_batch_norm=args.use_batch_norm,
dropout_rate=args.postnet_dropout_rate
)
# initialize parameters
self._reset_parameters(init_type=args.transformer_init,
init_enc_alpha=args.initial_encoder_alpha,
init_dec_alpha=args.initial_decoder_alpha)
# define teacher model
if args.teacher_model is not None:
self.teacher = self._load_teacher_model(args.teacher_model)
else:
self.teacher = None
# define duration calculator
if self.teacher is not None:
self.duration_calculator = DurationCalculator(self.teacher)
else:
self.duration_calculator = None
# transfer teacher parameters
if self.teacher is not None and args.transfer_encoder_from_teacher:
self._transfer_from_teacher(args.transferred_encoder_module)
# define criterions
self.criterion = FeedForwardTransformerLoss(
use_masking=args.use_masking,
use_weighted_masking=args.use_weighted_masking
)
add_arguments(parser)
staticmethod
¶
Add model-specific arguments to the parser.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_tts_fastspeech.py
@staticmethod
def add_arguments(parser):
"""Add model-specific arguments to the parser."""
group = parser.add_argument_group("feed-forward transformer model setting")
# network structure related
group.add_argument("--adim", default=384, type=int,
help="Number of attention transformation dimensions")
group.add_argument("--aheads", default=4, type=int,
help="Number of heads for multi head attention")
group.add_argument("--elayers", default=6, type=int,
help="Number of encoder layers")
group.add_argument("--eunits", default=1536, type=int,
help="Number of encoder hidden units")
group.add_argument("--dlayers", default=6, type=int,
help="Number of decoder layers")
group.add_argument("--dunits", default=1536, type=int,
help="Number of decoder hidden units")
group.add_argument("--positionwise-layer-type", default="linear", type=str,
choices=["linear", "conv1d", "conv1d-linear"],
help="Positionwise layer type.")
group.add_argument("--positionwise-conv-kernel-size", default=3, type=int,
help="Kernel size of positionwise conv1d layer")
group.add_argument("--postnet-layers", default=0, type=int,
help="Number of postnet layers")
group.add_argument("--postnet-chans", default=256, type=int,
help="Number of postnet channels")
group.add_argument("--postnet-filts", default=5, type=int,
help="Filter size of postnet")
group.add_argument("--use-batch-norm", default=True, type=strtobool,
help="Whether to use batch normalization")
group.add_argument("--use-scaled-pos-enc", default=True, type=strtobool,
help="Use trainable scaled positional encoding instead of the fixed scale one")
group.add_argument("--encoder-normalize-before", default=False, type=strtobool,
help="Whether to apply layer norm before encoder block")
group.add_argument("--decoder-normalize-before", default=False, type=strtobool,
help="Whether to apply layer norm before decoder block")
group.add_argument("--encoder-concat-after", default=False, type=strtobool,
help="Whether to concatenate attention layer's input and output in encoder")
group.add_argument("--decoder-concat-after", default=False, type=strtobool,
help="Whether to concatenate attention layer's input and output in decoder")
group.add_argument("--duration-predictor-layers", default=2, type=int,
help="Number of layers in duration predictor")
group.add_argument("--duration-predictor-chans", default=384, type=int,
help="Number of channels in duration predictor")
group.add_argument("--duration-predictor-kernel-size", default=3, type=int,
help="Kernel size in duration predictor")
group.add_argument("--teacher-model", default=None, type=str, nargs="?",
help="Teacher model file path")
group.add_argument("--reduction-factor", default=1, type=int,
help="Reduction factor")
group.add_argument("--spk-embed-dim", default=None, type=int,
help="Number of speaker embedding dimensions")
group.add_argument("--spk-embed-integration-type", type=str, default="add",
choices=["add", "concat"],
help="How to integrate speaker embedding")
# training related
group.add_argument("--transformer-init", type=str, default="pytorch",
choices=["pytorch", "xavier_uniform", "xavier_normal",
"kaiming_uniform", "kaiming_normal"],
help="How to initialize transformer parameters")
group.add_argument("--initial-encoder-alpha", type=float, default=1.0,
help="Initial alpha value in encoder's ScaledPositionalEncoding")
group.add_argument("--initial-decoder-alpha", type=float, default=1.0,
help="Initial alpha value in decoder's ScaledPositionalEncoding")
group.add_argument("--transformer-lr", default=1.0, type=float,
help="Initial value of learning rate")
group.add_argument("--transformer-warmup-steps", default=4000, type=int,
help="Optimizer warmup steps")
group.add_argument("--transformer-enc-dropout-rate", default=0.1, type=float,
help="Dropout rate for transformer encoder except for attention")
group.add_argument("--transformer-enc-positional-dropout-rate", default=0.1, type=float,
help="Dropout rate for transformer encoder positional encoding")
group.add_argument("--transformer-enc-attn-dropout-rate", default=0.1, type=float,
help="Dropout rate for transformer encoder self-attention")
group.add_argument("--transformer-dec-dropout-rate", default=0.1, type=float,
help="Dropout rate for transformer decoder except for attention and pos encoding")
group.add_argument("--transformer-dec-positional-dropout-rate", default=0.1, type=float,
help="Dropout rate for transformer decoder positional encoding")
group.add_argument("--transformer-dec-attn-dropout-rate", default=0.1, type=float,
help="Dropout rate for transformer decoder self-attention")
group.add_argument("--transformer-enc-dec-attn-dropout-rate", default=0.1, type=float,
help="Dropout rate for transformer encoder-decoder attention")
group.add_argument("--duration-predictor-dropout-rate", default=0.1, type=float,
help="Dropout rate for duration predictor")
group.add_argument("--postnet-dropout-rate", default=0.5, type=float,
help="Dropout rate in postnet")
group.add_argument("--transfer-encoder-from-teacher", default=True, type=strtobool,
help="Whether to transfer teacher's parameters")
group.add_argument("--transferred-encoder-module", default="all", type=str,
choices=["all", "embed"],
help="Encoder modules to be transferred from teacher")
# loss related
group.add_argument("--use-masking", default=True, type=strtobool,
help="Whether to use masking in calculation of loss")
group.add_argument("--use-weighted-masking", default=False, type=strtobool,
help="Whether to use weighted masking in calculation of loss")
return parser
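As a usage sketch (hypothetical values, not from the documentation), the same parser can be fed command-line style strings to override individual defaults, for example to enable a postnet and switch off scaled positional encoding:
import argparse

parser = argparse.ArgumentParser()
FeedForwardTransformer.add_arguments(parser)
# strtobool-typed flags accept strings such as "true"/"false"
args = parser.parse_args(["--postnet-layers", "5", "--use-scaled-pos-enc", "false"])
print(args.postnet_layers, args.use_scaled_pos_enc)   # 5 0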
calculate_all_attentions(self, xs, ilens, ys, olens, spembs=None, extras=None, *args, **kwargs)
¶
Calculate all of the attention weights.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
xs | Tensor | Batch of padded character ids (B, Tmax). | required |
ilens | LongTensor | Batch of lengths of each input batch (B,). | required |
ys | Tensor | Batch of padded target features (B, Lmax, odim). | required |
olens | LongTensor | Batch of the lengths of each target (B,). | required |
spembs | Tensor | Batch of speaker embedding vectors (B, spk_embed_dim). | None |
extras | Tensor | Batch of precalculated durations (B, Tmax, 1). | None |
Returns:
Type | Description |
---|---|
dict | Dict of attention weights and outputs. |
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_tts_fastspeech.py
def calculate_all_attentions(self, xs, ilens, ys, olens, spembs=None, extras=None, *args, **kwargs):
"""Calculate all of the attention weights.
Args:
xs (Tensor): Batch of padded character ids (B, Tmax).
ilens (LongTensor): Batch of lengths of each input batch (B,).
ys (Tensor): Batch of padded target features (B, Lmax, odim).
olens (LongTensor): Batch of the lengths of each target (B,).
spembs (Tensor, optional): Batch of speaker embedding vectors (B, spk_embed_dim).
extras (Tensor, optional): Batch of precalculated durations (B, Tmax, 1).
Returns:
dict: Dict of attention weights and outputs.
"""
with torch.no_grad():
# remove unnecessary padded part (for multi-gpus)
xs = xs[:, :max(ilens)]
ys = ys[:, :max(olens)]
if extras is not None:
extras = extras[:, :max(ilens)].squeeze(-1)
# forward propagation
outs = self._forward(xs, ilens, ys, olens, spembs=spembs, ds=extras, is_inference=False)[1]
att_ws_dict = dict()
for name, m in self.named_modules():
if isinstance(m, MultiHeadedAttention):
attn = m.attn.cpu().numpy()
if "encoder" in name:
attn = [a[:, :l, :l] for a, l in zip(attn, ilens.tolist())]
elif "decoder" in name:
if "src" in name:
attn = [a[:, :ol, :il] for a, il, ol in zip(attn, ilens.tolist(), olens.tolist())]
elif "self" in name:
attn = [a[:, :l, :l] for a, l in zip(attn, olens.tolist())]
else:
logging.warning("unknown attention module: " + name)
else:
logging.warning("unknown attention module: " + name)
att_ws_dict[name] = attn
att_ws_dict["predicted_fbank"] = [m[:l].T for m, l in zip(outs.cpu().numpy(), olens.tolist())]
return att_ws_dict
forward(self, xs, ilens, ys, olens, spembs=None, extras=None, *args, **kwargs)
¶
Calculate forward propagation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
xs | Tensor | Batch of padded character ids (B, Tmax). | required |
ilens | LongTensor | Batch of lengths of each input batch (B,). | required |
ys | Tensor | Batch of padded target features (B, Lmax, odim). | required |
olens | LongTensor | Batch of the lengths of each target (B,). | required |
spembs | Tensor | Batch of speaker embedding vectors (B, spk_embed_dim). | None |
extras | Tensor | Batch of precalculated durations (B, Tmax, 1). | None |
Returns:
Type | Description |
---|---|
Tensor | Loss value. |
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_tts_fastspeech.py
def forward(self, xs, ilens, ys, olens, spembs=None, extras=None, *args, **kwargs):
"""Calculate forward propagation.
Args:
xs (Tensor): Batch of padded character ids (B, Tmax).
ilens (LongTensor): Batch of lengths of each input batch (B,).
ys (Tensor): Batch of padded target features (B, Lmax, odim).
olens (LongTensor): Batch of the lengths of each target (B,).
spembs (Tensor, optional): Batch of speaker embedding vectors (B, spk_embed_dim).
extras (Tensor, optional): Batch of precalculated durations (B, Tmax, 1).
Returns:
Tensor: Loss value.
"""
# remove unnecessary padded part (for multi-gpus)
xs = xs[:, :max(ilens)]
ys = ys[:, :max(olens)]
if extras is not None:
extras = extras[:, :max(ilens)].squeeze(-1)
# forward propagation
before_outs, after_outs, ds, d_outs = self._forward(
xs, ilens, ys, olens, spembs=spembs, ds=extras, is_inference=False)
# modify the mod part of the ground truth
if self.reduction_factor > 1:
olens = olens.new([olen - olen % self.reduction_factor for olen in olens])
max_olen = max(olens)
ys = ys[:, :max_olen]
# calculate loss
if self.postnet is None:
l1_loss, duration_loss = self.criterion(None, before_outs, d_outs, ys, ds, ilens, olens)
else:
l1_loss, duration_loss = self.criterion(after_outs, before_outs, d_outs, ys, ds, ilens, olens)
loss = l1_loss + duration_loss
report_keys = [
{"l1_loss": l1_loss.item()},
{"duration_loss": duration_loss.item()},
{"loss": loss.item()},
]
# report extra information
if self.use_scaled_pos_enc:
report_keys += [
{"encoder_alpha": self.encoder.embed[-1].alpha.data.item()},
{"decoder_alpha": self.decoder.embed[-1].alpha.data.item()},
]
self.reporter.report(report_keys)
return loss
inference(self, x, inference_args, spemb=None, *args, **kwargs)
¶
Generate the sequence of features given the sequences of characters.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | Tensor | Input sequence of characters (T,). | required |
inference_args | Namespace | Dummy for compatibility. | required |
spemb | Tensor | Speaker embedding vector (spk_embed_dim). | None |
Returns:
Type | Description |
---|---|
Tensor | Output sequence of features (L, odim). None: Dummy for compatibility. None: Dummy for compatibility. |
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_tts_fastspeech.py
def inference(self, x, inference_args, spemb=None, *args, **kwargs):
"""Generate the sequence of features given the sequences of characters.
Args:
x (Tensor): Input sequence of characters (T,).
inference_args (Namespace): Dummy for compatibility.
spemb (Tensor, optional): Speaker embedding vector (spk_embed_dim).
Returns:
Tensor: Output sequence of features (L, odim).
None: Dummy for compatibility.
None: Dummy for compatibility.
"""
# setup batch axis
ilens = torch.tensor([x.shape[0]], dtype=torch.long, device=x.device)
xs = x.unsqueeze(0)
if spemb is not None:
spembs = spemb.unsqueeze(0)
else:
spembs = None
# inference
_, outs, _ = self._forward(xs, ilens, spembs=spembs, is_inference=True) # (1, L, odim)
return outs[0], None, None
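A usage sketch, assuming model is a trained FeedForwardTransformer checkpoint and the character ids come from some hypothetical text front end; inference_args is unused by this implementation, so None can be passed:
import torch

char_ids = torch.tensor([13, 5, 27, 27, 8, 2])     # made-up id sequence (T,)
model.eval()
with torch.no_grad():
    outs, _, _ = model.inference(char_ids, inference_args=None)
print(outs.shape)                                   # (L, odim), e.g. (L, 80)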
FeedForwardTransformerLoss (Module)
¶
Loss function module for feed-forward Transformer.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_tts_fastspeech.py
class FeedForwardTransformerLoss(torch.nn.Module):
"""Loss function module for feed-forward Transformer."""
def __init__(self, use_masking=True, use_weighted_masking=False):
"""Initialize feed-forward Transformer loss module.
Args:
use_masking (bool): Whether to apply masking for padded part in loss calculation.
use_weighted_masking (bool): Whether to apply weighted masking in loss calculation.
"""
super(FeedForwardTransformerLoss, self).__init__()
assert (use_masking != use_weighted_masking) or not use_masking
self.use_masking = use_masking
self.use_weighted_masking = use_weighted_masking
# define criterions
reduction = "none" if self.use_weighted_masking else "mean"
self.l1_criterion = torch.nn.L1Loss(reduction=reduction)
self.duration_criterion = DurationPredictorLoss(reduction=reduction)
def forward(self, after_outs, before_outs, d_outs, ys, ds, ilens, olens):
"""Calculate forward propagation.
Args:
after_outs (Tensor): Batch of outputs after postnets (B, Lmax, odim).
before_outs (Tensor): Batch of outputs before postnets (B, Lmax, odim).
d_outs (Tensor): Batch of outputs of duration predictor (B, Tmax).
ys (Tensor): Batch of target features (B, Lmax, odim).
ds (Tensor): Batch of durations (B, Tmax).
ilens (LongTensor): Batch of the lengths of each input (B,).
olens (LongTensor): Batch of the lengths of each target (B,).
Returns:
Tensor: L1 loss value.
Tensor: Duration predictor loss value.
"""
# apply mask to remove padded part
if self.use_masking:
duration_masks = make_non_pad_mask(ilens).to(ys.device)
d_outs = d_outs.masked_select(duration_masks)
ds = ds.masked_select(duration_masks)
out_masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
before_outs = before_outs.masked_select(out_masks)
after_outs = after_outs.masked_select(out_masks) if after_outs is not None else None
ys = ys.masked_select(out_masks)
# calculate loss
l1_loss = self.l1_criterion(before_outs, ys)
if after_outs is not None:
l1_loss += self.l1_criterion(after_outs, ys)
duration_loss = self.duration_criterion(d_outs, ds)
# make weighted mask and apply it
if self.use_weighted_masking:
out_masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
out_weights = out_masks.float() / out_masks.sum(dim=1, keepdim=True).float()
out_weights /= ys.size(0) * ys.size(2)
duration_masks = make_non_pad_mask(ilens).to(ys.device)
duration_weights = duration_masks.float() / duration_masks.sum(dim=1, keepdim=True).float()
duration_weights /= ds.size(0)
# apply weight
l1_loss = l1_loss.mul(out_weights).masked_select(out_masks).sum()
duration_loss = duration_loss.mul(duration_weights).masked_select(duration_masks).sum()
return l1_loss, duration_loss
__init__(self, use_masking=True, use_weighted_masking=False)
special
¶
Initialize feed-forward Transformer loss module.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
use_masking | bool | Whether to apply masking for padded part in loss calculation. | True |
use_weighted_masking | bool | Whether to apply weighted masking in loss calculation. | False |
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_tts_fastspeech.py
def __init__(self, use_masking=True, use_weighted_masking=False):
"""Initialize feed-forward Transformer loss module.
Args:
use_masking (bool): Whether to apply masking for padded part in loss calculation.
use_weighted_masking (bool): Whether to apply weighted masking in loss calculation.
"""
super(FeedForwardTransformerLoss, self).__init__()
assert (use_masking != use_weighted_masking) or not use_masking
self.use_masking = use_masking
self.use_weighted_masking = use_weighted_masking
# define criterions
reduction = "none" if self.use_weighted_masking else "mean"
self.l1_criterion = torch.nn.L1Loss(reduction=reduction)
self.duration_criterion = DurationPredictorLoss(reduction=reduction)
forward(self, after_outs, before_outs, d_outs, ys, ds, ilens, olens)
¶
Calculate forward propagation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
after_outs | Tensor | Batch of outputs after postnets (B, Lmax, odim). | required |
before_outs | Tensor | Batch of outputs before postnets (B, Lmax, odim). | required |
d_outs | Tensor | Batch of outputs of duration predictor (B, Tmax). | required |
ys | Tensor | Batch of target features (B, Lmax, odim). | required |
ds | Tensor | Batch of durations (B, Tmax). | required |
ilens | LongTensor | Batch of the lengths of each input (B,). | required |
olens | LongTensor | Batch of the lengths of each target (B,). | required |
Returns:
Type | Description |
---|---|
Tensor | L1 loss value. Tensor: Duration predictor loss value. |
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_tts_fastspeech.py
def forward(self, after_outs, before_outs, d_outs, ys, ds, ilens, olens):
"""Calculate forward propagation.
Args:
after_outs (Tensor): Batch of outputs after postnets (B, Lmax, odim).
before_outs (Tensor): Batch of outputs before postnets (B, Lmax, odim).
d_outs (Tensor): Batch of outputs of duration predictor (B, Tmax).
ys (Tensor): Batch of target features (B, Lmax, odim).
ds (Tensor): Batch of durations (B, Tmax).
ilens (LongTensor): Batch of the lengths of each input (B,).
olens (LongTensor): Batch of the lengths of each target (B,).
Returns:
Tensor: L1 loss value.
Tensor: Duration predictor loss value.
"""
# apply mask to remove padded part
if self.use_masking:
duration_masks = make_non_pad_mask(ilens).to(ys.device)
d_outs = d_outs.masked_select(duration_masks)
ds = ds.masked_select(duration_masks)
out_masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
before_outs = before_outs.masked_select(out_masks)
after_outs = after_outs.masked_select(out_masks) if after_outs is not None else None
ys = ys.masked_select(out_masks)
# calculate loss
l1_loss = self.l1_criterion(before_outs, ys)
if after_outs is not None:
l1_loss += self.l1_criterion(after_outs, ys)
duration_loss = self.duration_criterion(d_outs, ds)
# make weighted mask and apply it
if self.use_weighted_masking:
out_masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
out_weights = out_masks.float() / out_masks.sum(dim=1, keepdim=True).float()
out_weights /= ys.size(0) * ys.size(2)
duration_masks = make_non_pad_mask(ilens).to(ys.device)
duration_weights = duration_masks.float() / duration_masks.sum(dim=1, keepdim=True).float()
duration_weights /= ds.size(0)
# apply weight
l1_loss = l1_loss.mul(out_weights).masked_select(out_masks).sum()
duration_loss = duration_loss.mul(duration_weights).masked_select(duration_masks).sum()
return l1_loss, duration_loss
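To make the tensor shapes concrete, here is a small self-contained sketch (random values and assumed shapes only, mirroring how FeedForwardTransformer.forward calls the criterion) of computing the two loss terms:
import torch

B, Tmax, Lmax, odim = 2, 6, 24, 80
criterion = FeedForwardTransformerLoss(use_masking=True)

before_outs = torch.randn(B, Lmax, odim)            # decoder outputs before postnet
after_outs = torch.randn(B, Lmax, odim)             # outputs after postnet (or None)
ys = torch.randn(B, Lmax, odim)                     # target features
d_outs = torch.randn(B, Tmax)                       # predicted (log-domain) durations
ds = torch.randint(1, 5, (B, Tmax))                 # ground-truth durations in frames
ilens = torch.tensor([6, 4])                        # input lengths
olens = torch.tensor([24, 18])                      # output lengths

l1_loss, duration_loss = criterion(after_outs, before_outs, d_outs, ys, ds, ilens, olens)
print(float(l1_loss), float(duration_loss))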
e2e_tts_tacotron2
¶
Tacotron 2 related modules.
GuidedAttentionLoss (Module)
¶
Guided attention loss function module.
This module calculates the guided attention loss described in Efficiently Trainable Text-to-Speech System Based
on Deep Convolutional Networks with Guided Attention (https://arxiv.org/abs/1710.08969),
which forces the attention to be diagonal.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_tts_tacotron2.py
class GuidedAttentionLoss(torch.nn.Module):
"""Guided attention loss function module.
This module calculates the guided attention loss described in `Efficiently Trainable Text-to-Speech System Based
on Deep Convolutional Networks with Guided Attention`_, which forces the attention to be diagonal.
.. _`Efficiently Trainable Text-to-Speech System Based on Deep Convolutional Networks with Guided Attention`:
https://arxiv.org/abs/1710.08969
"""
def __init__(self, sigma=0.4, alpha=1.0, reset_always=True):
"""Initialize guided attention loss module.
Args:
sigma (float, optional): Standard deviation controlling how close the attention is to a diagonal.
alpha (float, optional): Scaling coefficient (lambda).
reset_always (bool, optional): Whether to always reset masks.
"""
super(GuidedAttentionLoss, self).__init__()
self.sigma = sigma
self.alpha = alpha
self.reset_always = reset_always
self.guided_attn_masks = None
self.masks = None
def _reset_masks(self):
self.guided_attn_masks = None
self.masks = None
def forward(self, att_ws, ilens, olens):
"""Calculate forward propagation.
Args:
att_ws (Tensor): Batch of attention weights (B, T_max_out, T_max_in).
ilens (LongTensor): Batch of input lengths (B,).
olens (LongTensor): Batch of output lengths (B,).
Returns:
Tensor: Guided attention loss value.
"""
if self.guided_attn_masks is None:
self.guided_attn_masks = self._make_guided_attention_masks(ilens, olens).to(att_ws.device)
if self.masks is None:
self.masks = self._make_masks(ilens, olens).to(att_ws.device)
losses = self.guided_attn_masks * att_ws
loss = torch.mean(losses.masked_select(self.masks))
if self.reset_always:
self._reset_masks()
return self.alpha * loss
def _make_guided_attention_masks(self, ilens, olens):
n_batches = len(ilens)
max_ilen = max(ilens)
max_olen = max(olens)
guided_attn_masks = torch.zeros((n_batches, max_olen, max_ilen))
for idx, (ilen, olen) in enumerate(zip(ilens, olens)):
guided_attn_masks[idx, :olen, :ilen] = self._make_guided_attention_mask(ilen, olen, self.sigma)
return guided_attn_masks
@staticmethod
def _make_guided_attention_mask(ilen, olen, sigma):
"""Make guided attention mask.
Examples:
>>> guided_attn_mask = _make_guided_attention_mask(5, 5, 0.4)
>>> guided_attn_mask.shape
torch.Size([5, 5])
>>> guided_attn_mask
tensor([[0.0000, 0.1175, 0.3935, 0.6753, 0.8647],
[0.1175, 0.0000, 0.1175, 0.3935, 0.6753],
[0.3935, 0.1175, 0.0000, 0.1175, 0.3935],
[0.6753, 0.3935, 0.1175, 0.0000, 0.1175],
[0.8647, 0.6753, 0.3935, 0.1175, 0.0000]])
>>> guided_attn_mask = _make_guided_attention_mask(3, 6, 0.4)
>>> guided_attn_mask.shape
torch.Size([6, 3])
>>> guided_attn_mask
tensor([[0.0000, 0.2934, 0.7506],
[0.0831, 0.0831, 0.5422],
[0.2934, 0.0000, 0.2934],
[0.5422, 0.0831, 0.0831],
[0.7506, 0.2934, 0.0000],
[0.8858, 0.5422, 0.0831]])
"""
grid_x, grid_y = torch.meshgrid(torch.arange(olen), torch.arange(ilen))
grid_x, grid_y = grid_x.float(), grid_y.float()
return 1.0 - torch.exp(-(grid_y / ilen - grid_x / olen) ** 2 / (2 * (sigma ** 2)))
@staticmethod
def _make_masks(ilens, olens):
"""Make masks indicating non-padded part.
Args:
ilens (LongTensor or List): Batch of lengths (B,).
olens (LongTensor or List): Batch of lengths (B,).
Returns:
Tensor: Mask tensor indicating non-padded part.
dtype=torch.uint8 in PyTorch versions before 1.2
dtype=torch.bool in PyTorch 1.2 and later
Examples:
>>> ilens, olens = [5, 2], [8, 5]
>>> _make_masks(ilens, olens)
tensor([[[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1]],
[[1, 1, 0, 0, 0],
[1, 1, 0, 0, 0],
[1, 1, 0, 0, 0],
[1, 1, 0, 0, 0],
[1, 1, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]]], dtype=torch.uint8)
"""
in_masks = make_non_pad_mask(ilens) # (B, T_in)
out_masks = make_non_pad_mask(olens) # (B, T_out)
return out_masks.unsqueeze(-1) & in_masks.unsqueeze(-2) # (B, T_out, T_in)
__init__(self, sigma=0.4, alpha=1.0, reset_always=True)
special
¶
Initialize guided attention loss module.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
sigma | float | Standard deviation controlling how close the attention is to a diagonal. | 0.4 |
alpha | float | Scaling coefficient (lambda). | 1.0 |
reset_always | bool | Whether to always reset masks. | True |
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_tts_tacotron2.py
def __init__(self, sigma=0.4, alpha=1.0, reset_always=True):
"""Initialize guided attention loss module.
Args:
sigma (float, optional): Standard deviation controlling how close the attention is to a diagonal.
alpha (float, optional): Scaling coefficient (lambda).
reset_always (bool, optional): Whether to always reset masks.
"""
super(GuidedAttentionLoss, self).__init__()
self.sigma = sigma
self.alpha = alpha
self.reset_always = reset_always
self.guided_attn_masks = None
self.masks = None
forward(self, att_ws, ilens, olens)
¶
Calculate forward propagation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
att_ws | Tensor | Batch of attention weights (B, T_max_out, T_max_in). | required |
ilens | LongTensor | Batch of input lengths (B,). | required |
olens | LongTensor | Batch of output lengths (B,). | required |
Returns:
Type | Description |
---|---|
Tensor | Guided attention loss value. |
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_tts_tacotron2.py
def forward(self, att_ws, ilens, olens):
"""Calculate forward propagation.
Args:
att_ws (Tensor): Batch of attention weights (B, T_max_out, T_max_in).
ilens (LongTensor): Batch of input lengths (B,).
olens (LongTensor): Batch of output lengths (B,).
Returns:
Tensor: Guided attention loss value.
"""
if self.guided_attn_masks is None:
self.guided_attn_masks = self._make_guided_attention_masks(ilens, olens).to(att_ws.device)
if self.masks is None:
self.masks = self._make_masks(ilens, olens).to(att_ws.device)
losses = self.guided_attn_masks * att_ws
loss = torch.mean(losses.masked_select(self.masks))
if self.reset_always:
self._reset_masks()
return self.alpha * loss
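A minimal usage sketch (random attention weights and made-up lengths, not from the documentation) showing how the loss penalizes off-diagonal attention; plain Python lists are used for the lengths, which the docstrings above state is accepted alongside LongTensor:
import torch

ga_loss = GuidedAttentionLoss(sigma=0.4, alpha=1.0)
att_ws = torch.softmax(torch.randn(2, 8, 5), dim=-1)   # (B, T_max_out, T_max_in)
ilens, olens = [5, 3], [8, 6]                           # per-utterance input/output lengths
loss = ga_loss(att_ws, ilens, olens)
print(float(loss))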
Tacotron2 (TTSInterface, Module)
¶
Tacotron2 module for end-to-end text-to-speech (E2E-TTS).
This is a module of the spectrogram prediction network in Tacotron 2 described in Natural TTS Synthesis
by Conditioning WaveNet on Mel Spectrogram Predictions (https://arxiv.org/abs/1712.05884),
which converts a sequence of characters into a sequence of Mel filterbanks.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_tts_tacotron2.py
class Tacotron2(TTSInterface, torch.nn.Module):
"""Tacotron2 module for end-to-end text-to-speech (E2E-TTS).
This is a module of Spectrogram prediction network in Tacotron2 described in `Natural TTS Synthesis
by Conditioning WaveNet on Mel Spectrogram Predictions`_, which converts the sequence of characters
into the sequence of Mel-filterbanks.
.. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
https://arxiv.org/abs/1712.05884
"""
@staticmethod
def add_arguments(parser):
"""Add model-specific arguments to the parser."""
group = parser.add_argument_group("tacotron 2 model setting")
# encoder
group.add_argument('--embed-dim', default=512, type=int,
help='Number of dimension of embedding')
group.add_argument('--elayers', default=1, type=int,
help='Number of encoder layers')
group.add_argument('--eunits', '-u', default=512, type=int,
help='Number of encoder hidden units')
group.add_argument('--econv-layers', default=3, type=int,
help='Number of encoder convolution layers')
group.add_argument('--econv-chans', default=512, type=int,
help='Number of encoder convolution channels')
group.add_argument('--econv-filts', default=5, type=int,
help='Filter size of encoder convolution')
# attention
group.add_argument('--atype', default="location", type=str,
choices=["forward_ta", "forward", "location"],
help='Type of attention mechanism')
group.add_argument('--adim', default=512, type=int,
help='Number of attention transformation dimensions')
group.add_argument('--aconv-chans', default=32, type=int,
help='Number of attention convolution channels')
group.add_argument('--aconv-filts', default=15, type=int,
help='Filter size of attention convolution')
group.add_argument('--cumulate-att-w', default=True, type=strtobool,
help="Whether or not to cumulate attention weights")
# decoder
group.add_argument('--dlayers', default=2, type=int,
help='Number of decoder layers')
group.add_argument('--dunits', default=1024, type=int,
help='Number of decoder hidden units')
group.add_argument('--prenet-layers', default=2, type=int,
help='Number of prenet layers')
group.add_argument('--prenet-units', default=256, type=int,
help='Number of prenet hidden units')
group.add_argument('--postnet-layers', default=5, type=int,
help='Number of postnet layers')
group.add_argument('--postnet-chans', default=512, type=int,
help='Number of postnet channels')
group.add_argument('--postnet-filts', default=5, type=int,
help='Filter size of postnet')
group.add_argument('--output-activation', default=None, type=str, nargs='?',
help='Output activation function')
# cbhg
group.add_argument('--use-cbhg', default=False, type=strtobool,
help='Whether to use CBHG module')
group.add_argument('--cbhg-conv-bank-layers', default=8, type=int,
help='Number of convolutional bank layers in CBHG')
group.add_argument('--cbhg-conv-bank-chans', default=128, type=int,
help='Number of convolutional bank channels in CBHG')
group.add_argument('--cbhg-conv-proj-filts', default=3, type=int,
help='Filter size of convolutional projection layer in CBHG')
group.add_argument('--cbhg-conv-proj-chans', default=256, type=int,
help='Number of convolutional projection channels in CBHG')
group.add_argument('--cbhg-highway-layers', default=4, type=int,
help='Number of highway layers in CBHG')
group.add_argument('--cbhg-highway-units', default=128, type=int,
help='Number of highway units in CBHG')
group.add_argument('--cbhg-gru-units', default=256, type=int,
help='Number of GRU units in CBHG')
# model (parameter) related
group.add_argument('--use-batch-norm', default=True, type=strtobool,
help='Whether to use batch normalization')
group.add_argument('--use-concate', default=True, type=strtobool,
help='Whether to concatenate encoder embedding with decoder outputs')
group.add_argument('--use-residual', default=True, type=strtobool,
help='Whether to use residual connection in conv layer')
group.add_argument('--dropout-rate', default=0.5, type=float,
help='Dropout rate')
group.add_argument('--zoneout-rate', default=0.1, type=float,
help='Zoneout rate')
group.add_argument('--reduction-factor', default=1, type=int,
help='Reduction factor')
group.add_argument("--spk-embed-dim", default=None, type=int,
help="Number of speaker embedding dimensions")
group.add_argument("--spc-dim", default=None, type=int,
help="Number of spectrogram dimensions")
group.add_argument("--pretrained-model", default=None, type=str,
help="Pretrained model path")
# loss related
group.add_argument('--use-masking', default=False, type=strtobool,
help='Whether to use masking in calculation of loss')
group.add_argument('--use-weighted-masking', default=False, type=strtobool,
help='Whether to use weighted masking in calculation of loss')
group.add_argument('--bce-pos-weight', default=20.0, type=float,
help='Positive sample weight in BCE calculation (only for use-masking=True)')
group.add_argument("--use-guided-attn-loss", default=False, type=strtobool,
help="Whether to use guided attention loss")
group.add_argument("--guided-attn-loss-sigma", default=0.4, type=float,
help="Sigma in guided attention loss")
group.add_argument("--guided-attn-loss-lambda", default=1.0, type=float,
help="Lambda in guided attention loss")
return parser
def __init__(self, idim, odim, args=None):
"""Initialize Tacotron2 module.
Args:
idim (int): Dimension of the inputs.
odim (int): Dimension of the outputs.
args (Namespace, optional):
- spk_embed_dim (int): Dimension of the speaker embedding.
- embed_dim (int): Dimension of character embedding.
- elayers (int): The number of encoder blstm layers.
- eunits (int): The number of encoder blstm units.
- econv_layers (int): The number of encoder conv layers.
- econv_filts (int): The number of encoder conv filter size.
- econv_chans (int): The number of encoder conv filter channels.
- dlayers (int): The number of decoder lstm layers.
- dunits (int): The number of decoder lstm units.
- prenet_layers (int): The number of prenet layers.
- prenet_units (int): The number of prenet units.
- postnet_layers (int): The number of postnet layers.
- postnet_filts (int): The number of postnet filter size.
- postnet_chans (int): The number of postnet filter channels.
- output_activation (int): The name of activation function for outputs.
- adim (int): The number of dimension of mlp in attention.
- aconv_chans (int): The number of attention conv filter channels.
- aconv_filts (int): The number of attention conv filter size.
- cumulate_att_w (bool): Whether to cumulate previous attention weight.
- use_batch_norm (bool): Whether to use batch normalization.
- use_concate (int): Whether to concatenate encoder embedding with decoder lstm outputs.
- dropout_rate (float): Dropout rate.
- zoneout_rate (float): Zoneout rate.
- reduction_factor (int): Reduction factor.
- spk_embed_dim (int): Number of speaker embedding dimensions.
- spc_dim (int): Number of spectrogram embedding dimensions (only for use_cbhg=True).
- use_cbhg (bool): Whether to use CBHG module.
- cbhg_conv_bank_layers (int): The number of convolutional banks in CBHG.
- cbhg_conv_bank_chans (int): The number of channels of convolutional bank in CBHG.
- cbhg_proj_filts (int): The filter size of the projection layer in CBHG.
- cbhg_proj_chans (int): The number of channels of projection layer in CBHG.
- cbhg_highway_layers (int): The number of layers of highway network in CBHG.
- cbhg_highway_units (int): The number of units of highway network in CBHG.
- cbhg_gru_units (int): The number of units of GRU in CBHG.
- use_masking (bool): Whether to apply masking for padded part in loss calculation.
- use_weighted_masking (bool): Whether to apply weighted masking in loss calculation.
- bce_pos_weight (float): Weight of positive sample of stop token (only for use_masking=True).
- use_guided_attn_loss (bool): Whether to use guided attention loss.
- guided_attn_loss_sigma (float): Sigma in guided attention loss.
- guided_attn_loss_lambda (float): Lambda in guided attention loss.
"""
# initialize base classes
TTSInterface.__init__(self)
torch.nn.Module.__init__(self)
# fill missing arguments
args = fill_missing_args(args, self.add_arguments)
# store hyperparameters
self.idim = idim
self.odim = odim
self.spk_embed_dim = args.spk_embed_dim
self.cumulate_att_w = args.cumulate_att_w
self.reduction_factor = args.reduction_factor
self.use_cbhg = args.use_cbhg
self.use_guided_attn_loss = args.use_guided_attn_loss
# define activation function for the final output
if args.output_activation is None:
self.output_activation_fn = None
elif hasattr(F, args.output_activation):
self.output_activation_fn = getattr(F, args.output_activation)
else:
raise ValueError('there is no such an activation function. (%s)' % args.output_activation)
# set padding idx
padding_idx = 0
# define network modules
self.enc = Encoder(idim=idim,
embed_dim=args.embed_dim,
elayers=args.elayers,
eunits=args.eunits,
econv_layers=args.econv_layers,
econv_chans=args.econv_chans,
econv_filts=args.econv_filts,
use_batch_norm=args.use_batch_norm,
use_residual=args.use_residual,
dropout_rate=args.dropout_rate,
padding_idx=padding_idx)
dec_idim = args.eunits if args.spk_embed_dim is None else args.eunits + args.spk_embed_dim
if args.atype == "location":
att = AttLoc(dec_idim,
args.dunits,
args.adim,
args.aconv_chans,
args.aconv_filts)
elif args.atype == "forward":
att = AttForward(dec_idim,
args.dunits,
args.adim,
args.aconv_chans,
args.aconv_filts)
if self.cumulate_att_w:
logging.warning("cumulation of attention weights is disabled in forward attention.")
self.cumulate_att_w = False
elif args.atype == "forward_ta":
att = AttForwardTA(dec_idim,
args.dunits,
args.adim,
args.aconv_chans,
args.aconv_filts,
odim)
if self.cumulate_att_w:
logging.warning("cumulation of attention weights is disabled in forward attention.")
self.cumulate_att_w = False
else:
raise NotImplementedError("Support only location or forward")
self.dec = Decoder(idim=dec_idim,
odim=odim,
att=att,
dlayers=args.dlayers,
dunits=args.dunits,
prenet_layers=args.prenet_layers,
prenet_units=args.prenet_units,
postnet_layers=args.postnet_layers,
postnet_chans=args.postnet_chans,
postnet_filts=args.postnet_filts,
output_activation_fn=self.output_activation_fn,
cumulate_att_w=self.cumulate_att_w,
use_batch_norm=args.use_batch_norm,
use_concate=args.use_concate,
dropout_rate=args.dropout_rate,
zoneout_rate=args.zoneout_rate,
reduction_factor=args.reduction_factor)
self.taco2_loss = Tacotron2Loss(use_masking=args.use_masking,
use_weighted_masking=args.use_weighted_masking,
bce_pos_weight=args.bce_pos_weight)
if self.use_guided_attn_loss:
self.attn_loss = GuidedAttentionLoss(
sigma=args.guided_attn_loss_sigma,
alpha=args.guided_attn_loss_lambda,
)
if self.use_cbhg:
self.cbhg = CBHG(idim=odim,
odim=args.spc_dim,
conv_bank_layers=args.cbhg_conv_bank_layers,
conv_bank_chans=args.cbhg_conv_bank_chans,
conv_proj_filts=args.cbhg_conv_proj_filts,
conv_proj_chans=args.cbhg_conv_proj_chans,
highway_layers=args.cbhg_highway_layers,
highway_units=args.cbhg_highway_units,
gru_units=args.cbhg_gru_units)
self.cbhg_loss = CBHGLoss(use_masking=args.use_masking)
# load pretrained model
if args.pretrained_model is not None:
self.load_pretrained_model(args.pretrained_model)
def forward(self, xs, ilens, ys, labels, olens, spembs=None, extras=None, *args, **kwargs):
"""Calculate forward propagation.
Args:
xs (Tensor): Batch of padded character ids (B, Tmax).
ilens (LongTensor): Batch of lengths of each input batch (B,).
ys (Tensor): Batch of padded target features (B, Lmax, odim).
olens (LongTensor): Batch of the lengths of each target (B,).
spembs (Tensor, optional): Batch of speaker embedding vectors (B, spk_embed_dim).
extras (Tensor, optional): Batch of groundtruth spectrograms (B, Lmax, spc_dim).
Returns:
Tensor: Loss value.
"""
# remove unnecessary padded part (for multi-gpus)
max_in = max(ilens)
max_out = max(olens)
if max_in != xs.shape[1]:
xs = xs[:, :max_in]
if max_out != ys.shape[1]:
ys = ys[:, :max_out]
labels = labels[:, :max_out]
# calculate tacotron2 outputs
hs, hlens = self.enc(xs, ilens)
if self.spk_embed_dim is not None:
spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1)
hs = torch.cat([hs, spembs], dim=-1)
after_outs, before_outs, logits, att_ws = self.dec(hs, hlens, ys)
# modify the mod part of the groundtruth
if self.reduction_factor > 1:
olens = olens.new([olen - olen % self.reduction_factor for olen in olens])
max_out = max(olens)
ys = ys[:, :max_out]
labels = labels[:, :max_out]
labels[:, -1] = 1.0 # make sure at least one frame has 1
# calculate taco2 loss
l1_loss, mse_loss, bce_loss = self.taco2_loss(
after_outs, before_outs, logits, ys, labels, olens)
loss = l1_loss + mse_loss + bce_loss
report_keys = [
{'l1_loss': l1_loss.item()},
{'mse_loss': mse_loss.item()},
{'bce_loss': bce_loss.item()},
]
# calculate attention loss
if self.use_guided_attn_loss:
# NOTE(kan-bayashi): length of output for auto-regressive input will be changed when r > 1
if self.reduction_factor > 1:
olens_in = olens.new([olen // self.reduction_factor for olen in olens])
else:
olens_in = olens
attn_loss = self.attn_loss(att_ws, ilens, olens_in)
loss = loss + attn_loss
report_keys += [
{'attn_loss': attn_loss.item()},
]
# calculate cbhg loss
if self.use_cbhg:
# remove unnecessary padded part (for multi-gpus)
if max_out != extras.shape[1]:
extras = extras[:, :max_out]
# calculate cbhg outputs & loss and report them
cbhg_outs, _ = self.cbhg(after_outs, olens)
cbhg_l1_loss, cbhg_mse_loss = self.cbhg_loss(cbhg_outs, extras, olens)
loss = loss + cbhg_l1_loss + cbhg_mse_loss
report_keys += [
{'cbhg_l1_loss': cbhg_l1_loss.item()},
{'cbhg_mse_loss': cbhg_mse_loss.item()},
]
report_keys += [{'loss': loss.item()}]
self.reporter.report(report_keys)
return loss
def inference(self, x, inference_args, spemb=None, *args, **kwargs):
"""Generate the sequence of features given the sequences of characters.
Args:
x (Tensor): Input sequence of characters (T,).
inference_args (Namespace):
- threshold (float): Threshold in inference.
- minlenratio (float): Minimum length ratio in inference.
- maxlenratio (float): Maximum length ratio in inference.
spemb (Tensor, optional): Speaker embedding vector (spk_embed_dim).
Returns:
Tensor: Output sequence of features (L, odim).
Tensor: Output sequence of stop probabilities (L,).
Tensor: Attention weights (L, T).
"""
# get options
threshold = inference_args.threshold
minlenratio = inference_args.minlenratio
maxlenratio = inference_args.maxlenratio
use_att_constraint = getattr(inference_args, "use_att_constraint", False) # keep compatibility
backward_window = inference_args.backward_window if use_att_constraint else 0
forward_window = inference_args.forward_window if use_att_constraint else 0
# inference
h = self.enc.inference(x)
if self.spk_embed_dim is not None:
spemb = F.normalize(spemb, dim=0).unsqueeze(0).expand(h.size(0), -1)
h = torch.cat([h, spemb], dim=-1)
outs, probs, att_ws = self.dec.inference(h, threshold, minlenratio, maxlenratio,
use_att_constraint=use_att_constraint,
backward_window=backward_window,
forward_window=forward_window)
if self.use_cbhg:
cbhg_outs = self.cbhg.inference(outs)
return cbhg_outs, probs, att_ws
else:
return outs, probs, att_ws
def calculate_all_attentions(self, xs, ilens, ys, spembs=None, keep_tensor=False, *args, **kwargs):
"""Calculate all of the attention weights.
Args:
xs (Tensor): Batch of padded character ids (B, Tmax).
ilens (LongTensor): Batch of lengths of each input batch (B,).
ys (Tensor): Batch of padded target features (B, Lmax, odim).
olens (LongTensor): Batch of the lengths of each target (B,).
spembs (Tensor, optional): Batch of speaker embedding vectors (B, spk_embed_dim).
keep_tensor (bool, optional): Whether to keep original tensor.
Returns:
Union[ndarray, Tensor]: Batch of attention weights (B, Lmax, Tmax).
"""
# check ilens type (should be list of int)
if isinstance(ilens, torch.Tensor) or isinstance(ilens, np.ndarray):
ilens = list(map(int, ilens))
self.eval()
with torch.no_grad():
hs, hlens = self.enc(xs, ilens)
if self.spk_embed_dim is not None:
spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1)
hs = torch.cat([hs, spembs], dim=-1)
att_ws = self.dec.calculate_all_attentions(hs, hlens, ys)
self.train()
if keep_tensor:
return att_ws
else:
return att_ws.cpu().numpy()
@property
def base_plot_keys(self):
"""Return base key names to plot during training. keys should match what `chainer.reporter` reports.
If you add the key `loss`, the reporter will report `main/loss` and `validation/main/loss` values.
also `loss.png` will be created as a figure visulizing `main/loss` and `validation/main/loss` values.
Returns:
list: List of strings which are base keys to plot during training.
"""
plot_keys = ['loss', 'l1_loss', 'mse_loss', 'bce_loss']
if self.use_guided_attn_loss:
plot_keys += ['attn_loss']
if self.use_cbhg:
plot_keys += ['cbhg_l1_loss', 'cbhg_mse_loss']
return plot_keys
base_plot_keys
property
readonly
¶
Return base key names to plot during training. Keys should match what `chainer.reporter` reports.
If you add the key `loss`, the reporter will report `main/loss` and `validation/main/loss` values.
In addition, `loss.png` will be created as a figure visualizing `main/loss` and `validation/main/loss` values.
Returns:

Type | Description
---|---
list | List of strings which are base keys to plot during training.
__init__(self, idim, odim, args=None)
special
¶
Initialize Tacotron2 module.
Parameters:

Name | Type | Description | Default
---|---|---|---
idim | int | Dimension of the inputs. | required
odim | int | Dimension of the outputs. | required
args | Namespace | Model hyperparameters; see the full argument list in the source docstring below. | None
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_tts_tacotron2.py
def __init__(self, idim, odim, args=None):
"""Initialize Tacotron2 module.
Args:
idim (int): Dimension of the inputs.
odim (int): Dimension of the outputs.
args (Namespace, optional):
- spk_embed_dim (int): Dimension of the speaker embedding.
- embed_dim (int): Dimension of character embedding.
- elayers (int): The number of encoder blstm layers.
- eunits (int): The number of encoder blstm units.
- econv_layers (int): The number of encoder conv layers.
- econv_filts (int): The number of encoder conv filter size.
- econv_chans (int): The number of encoder conv filter channels.
- dlayers (int): The number of decoder lstm layers.
- dunits (int): The number of decoder lstm units.
- prenet_layers (int): The number of prenet layers.
- prenet_units (int): The number of prenet units.
- postnet_layers (int): The number of postnet layers.
- postnet_filts (int): The number of postnet filter size.
- postnet_chans (int): The number of postnet filter channels.
- output_activation (int): The name of activation function for outputs.
- adim (int): The number of dimension of mlp in attention.
- aconv_chans (int): The number of attention conv filter channels.
- aconv_filts (int): The number of attention conv filter size.
- cumulate_att_w (bool): Whether to cumulate previous attention weight.
- use_batch_norm (bool): Whether to use batch normalization.
- use_concate (int): Whether to concatenate encoder embedding with decoder lstm outputs.
- dropout_rate (float): Dropout rate.
- zoneout_rate (float): Zoneout rate.
- reduction_factor (int): Reduction factor.
- spk_embed_dim (int): Number of speaker embedding dimensions.
- spc_dim (int): Number of spectrogram embedding dimensions (only for use_cbhg=True).
- use_cbhg (bool): Whether to use CBHG module.
- cbhg_conv_bank_layers (int): The number of convolutional banks in CBHG.
- cbhg_conv_bank_chans (int): The number of channels of convolutional bank in CBHG.
- cbhg_proj_filts (int): The filter size of the projection layer in CBHG.
- cbhg_proj_chans (int): The number of channels of projection layer in CBHG.
- cbhg_highway_layers (int): The number of layers of highway network in CBHG.
- cbhg_highway_units (int): The number of units of highway network in CBHG.
- cbhg_gru_units (int): The number of units of GRU in CBHG.
- use_masking (bool): Whether to apply masking for padded part in loss calculation.
- use_weighted_masking (bool): Whether to apply weighted masking in loss calculation.
- bce_pos_weight (float): Weight of positive sample of stop token (only for use_masking=True).
- use_guided_attn_loss (bool): Whether to use guided attention loss.
- guided_attn_loss_sigma (float): Sigma in guided attention loss.
- guided_attn_loss_lambda (float): Lambda in guided attention loss.
"""
# initialize base classes
TTSInterface.__init__(self)
torch.nn.Module.__init__(self)
# fill missing arguments
args = fill_missing_args(args, self.add_arguments)
# store hyperparameters
self.idim = idim
self.odim = odim
self.spk_embed_dim = args.spk_embed_dim
self.cumulate_att_w = args.cumulate_att_w
self.reduction_factor = args.reduction_factor
self.use_cbhg = args.use_cbhg
self.use_guided_attn_loss = args.use_guided_attn_loss
# define activation function for the final output
if args.output_activation is None:
self.output_activation_fn = None
elif hasattr(F, args.output_activation):
self.output_activation_fn = getattr(F, args.output_activation)
else:
raise ValueError('there is no such an activation function. (%s)' % args.output_activation)
# set padding idx
padding_idx = 0
# define network modules
self.enc = Encoder(idim=idim,
embed_dim=args.embed_dim,
elayers=args.elayers,
eunits=args.eunits,
econv_layers=args.econv_layers,
econv_chans=args.econv_chans,
econv_filts=args.econv_filts,
use_batch_norm=args.use_batch_norm,
use_residual=args.use_residual,
dropout_rate=args.dropout_rate,
padding_idx=padding_idx)
dec_idim = args.eunits if args.spk_embed_dim is None else args.eunits + args.spk_embed_dim
if args.atype == "location":
att = AttLoc(dec_idim,
args.dunits,
args.adim,
args.aconv_chans,
args.aconv_filts)
elif args.atype == "forward":
att = AttForward(dec_idim,
args.dunits,
args.adim,
args.aconv_chans,
args.aconv_filts)
if self.cumulate_att_w:
logging.warning("cumulation of attention weights is disabled in forward attention.")
self.cumulate_att_w = False
elif args.atype == "forward_ta":
att = AttForwardTA(dec_idim,
args.dunits,
args.adim,
args.aconv_chans,
args.aconv_filts,
odim)
if self.cumulate_att_w:
logging.warning("cumulation of attention weights is disabled in forward attention.")
self.cumulate_att_w = False
else:
raise NotImplementedError("Support only location or forward")
self.dec = Decoder(idim=dec_idim,
odim=odim,
att=att,
dlayers=args.dlayers,
dunits=args.dunits,
prenet_layers=args.prenet_layers,
prenet_units=args.prenet_units,
postnet_layers=args.postnet_layers,
postnet_chans=args.postnet_chans,
postnet_filts=args.postnet_filts,
output_activation_fn=self.output_activation_fn,
cumulate_att_w=self.cumulate_att_w,
use_batch_norm=args.use_batch_norm,
use_concate=args.use_concate,
dropout_rate=args.dropout_rate,
zoneout_rate=args.zoneout_rate,
reduction_factor=args.reduction_factor)
self.taco2_loss = Tacotron2Loss(use_masking=args.use_masking,
use_weighted_masking=args.use_weighted_masking,
bce_pos_weight=args.bce_pos_weight)
if self.use_guided_attn_loss:
self.attn_loss = GuidedAttentionLoss(
sigma=args.guided_attn_loss_sigma,
alpha=args.guided_attn_loss_lambda,
)
if self.use_cbhg:
self.cbhg = CBHG(idim=odim,
odim=args.spc_dim,
conv_bank_layers=args.cbhg_conv_bank_layers,
conv_bank_chans=args.cbhg_conv_bank_chans,
conv_proj_filts=args.cbhg_conv_proj_filts,
conv_proj_chans=args.cbhg_conv_proj_chans,
highway_layers=args.cbhg_highway_layers,
highway_units=args.cbhg_highway_units,
gru_units=args.cbhg_gru_units)
self.cbhg_loss = CBHGLoss(use_masking=args.use_masking)
# load pretrained model
if args.pretrained_model is not None:
self.load_pretrained_model(args.pretrained_model)
add_arguments(parser)
staticmethod
¶
Add model-specific arguments to the parser.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_tts_tacotron2.py
@staticmethod
def add_arguments(parser):
"""Add model-specific arguments to the parser."""
group = parser.add_argument_group("tacotron 2 model setting")
# encoder
group.add_argument('--embed-dim', default=512, type=int,
help='Number of dimension of embedding')
group.add_argument('--elayers', default=1, type=int,
help='Number of encoder layers')
group.add_argument('--eunits', '-u', default=512, type=int,
help='Number of encoder hidden units')
group.add_argument('--econv-layers', default=3, type=int,
help='Number of encoder convolution layers')
group.add_argument('--econv-chans', default=512, type=int,
help='Number of encoder convolution channels')
group.add_argument('--econv-filts', default=5, type=int,
help='Filter size of encoder convolution')
# attention
group.add_argument('--atype', default="location", type=str,
choices=["forward_ta", "forward", "location"],
help='Type of attention mechanism')
group.add_argument('--adim', default=512, type=int,
help='Number of attention transformation dimensions')
group.add_argument('--aconv-chans', default=32, type=int,
help='Number of attention convolution channels')
group.add_argument('--aconv-filts', default=15, type=int,
help='Filter size of attention convolution')
group.add_argument('--cumulate-att-w', default=True, type=strtobool,
help="Whether or not to cumulate attention weights")
# decoder
group.add_argument('--dlayers', default=2, type=int,
help='Number of decoder layers')
group.add_argument('--dunits', default=1024, type=int,
help='Number of decoder hidden units')
group.add_argument('--prenet-layers', default=2, type=int,
help='Number of prenet layers')
group.add_argument('--prenet-units', default=256, type=int,
help='Number of prenet hidden units')
group.add_argument('--postnet-layers', default=5, type=int,
help='Number of postnet layers')
group.add_argument('--postnet-chans', default=512, type=int,
help='Number of postnet channels')
group.add_argument('--postnet-filts', default=5, type=int,
help='Filter size of postnet')
group.add_argument('--output-activation', default=None, type=str, nargs='?',
help='Output activation function')
# cbhg
group.add_argument('--use-cbhg', default=False, type=strtobool,
help='Whether to use CBHG module')
group.add_argument('--cbhg-conv-bank-layers', default=8, type=int,
help='Number of convolutional bank layers in CBHG')
group.add_argument('--cbhg-conv-bank-chans', default=128, type=int,
help='Number of convolutional bank channels in CBHG')
group.add_argument('--cbhg-conv-proj-filts', default=3, type=int,
help='Filter size of convolutional projection layer in CBHG')
group.add_argument('--cbhg-conv-proj-chans', default=256, type=int,
help='Number of convolutional projection channels in CBHG')
group.add_argument('--cbhg-highway-layers', default=4, type=int,
help='Number of highway layers in CBHG')
group.add_argument('--cbhg-highway-units', default=128, type=int,
help='Number of highway units in CBHG')
group.add_argument('--cbhg-gru-units', default=256, type=int,
help='Number of GRU units in CBHG')
# model (parameter) related
group.add_argument('--use-batch-norm', default=True, type=strtobool,
help='Whether to use batch normalization')
group.add_argument('--use-concate', default=True, type=strtobool,
help='Whether to concatenate encoder embedding with decoder outputs')
group.add_argument('--use-residual', default=True, type=strtobool,
help='Whether to use residual connection in conv layer')
group.add_argument('--dropout-rate', default=0.5, type=float,
help='Dropout rate')
group.add_argument('--zoneout-rate', default=0.1, type=float,
help='Zoneout rate')
group.add_argument('--reduction-factor', default=1, type=int,
help='Reduction factor')
group.add_argument("--spk-embed-dim", default=None, type=int,
help="Number of speaker embedding dimensions")
group.add_argument("--spc-dim", default=None, type=int,
help="Number of spectrogram dimensions")
group.add_argument("--pretrained-model", default=None, type=str,
help="Pretrained model path")
# loss related
group.add_argument('--use-masking', default=False, type=strtobool,
help='Whether to use masking in calculation of loss')
group.add_argument('--use-weighted-masking', default=False, type=strtobool,
help='Whether to use weighted masking in calculation of loss')
group.add_argument('--bce-pos-weight', default=20.0, type=float,
help='Positive sample weight in BCE calculation (only for use-masking=True)')
group.add_argument("--use-guided-attn-loss", default=False, type=strtobool,
help="Whether to use guided attention loss")
group.add_argument("--guided-attn-loss-sigma", default=0.4, type=float,
help="Sigma in guided attention loss")
group.add_argument("--guided-attn-loss-lambda", default=1.0, type=float,
help="Lambda in guided attention loss")
return parser
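Because every option above ships with a default, the parser can be used to build a configuration Namespace directly. The sketch below is illustrative only: the import path and the idim/odim values are assumptions, and it requires the dependencies bundled with espnet_minimal to be installed:

```python
import argparse
# Assumed import path; adjust to where espnet_minimal lives in your checkout.
from tools.espnet_minimal.nets.pytorch_backend.e2e_tts_tacotron2 import Tacotron2

parser = argparse.ArgumentParser()
Tacotron2.add_arguments(parser)
args = parser.parse_args([])        # take the documented defaults

idim, odim = 40, 80                 # e.g. 40 input symbols, 80 mel bins (illustrative values)
model = Tacotron2(idim, odim, args)
print(sum(p.numel() for p in model.parameters()), "parameters")
```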
calculate_all_attentions(self, xs, ilens, ys, spembs=None, keep_tensor=False, *args, **kwargs)
¶
Calculate all of the attention weights.
Parameters:

Name | Type | Description | Default
---|---|---|---
xs | Tensor | Batch of padded character ids (B, Tmax). | required
ilens | LongTensor | Batch of lengths of each input batch (B,). | required
ys | Tensor | Batch of padded target features (B, Lmax, odim). | required
olens | LongTensor | Batch of the lengths of each target (B,). | required
spembs | Tensor | Batch of speaker embedding vectors (B, spk_embed_dim). | None
keep_tensor | bool | Whether to keep original tensor. | False

Returns:

Type | Description
---|---
Union[ndarray, Tensor] | Batch of attention weights (B, Lmax, Tmax).
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_tts_tacotron2.py
def calculate_all_attentions(self, xs, ilens, ys, spembs=None, keep_tensor=False, *args, **kwargs):
"""Calculate all of the attention weights.
Args:
xs (Tensor): Batch of padded character ids (B, Tmax).
ilens (LongTensor): Batch of lengths of each input batch (B,).
ys (Tensor): Batch of padded target features (B, Lmax, odim).
olens (LongTensor): Batch of the lengths of each target (B,).
spembs (Tensor, optional): Batch of speaker embedding vectors (B, spk_embed_dim).
keep_tensor (bool, optional): Whether to keep original tensor.
Returns:
Union[ndarray, Tensor]: Batch of attention weights (B, Lmax, Tmax).
"""
# check ilens type (should be list of int)
if isinstance(ilens, torch.Tensor) or isinstance(ilens, np.ndarray):
ilens = list(map(int, ilens))
self.eval()
with torch.no_grad():
hs, hlens = self.enc(xs, ilens)
if self.spk_embed_dim is not None:
spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1)
hs = torch.cat([hs, spembs], dim=-1)
att_ws = self.dec.calculate_all_attentions(hs, hlens, ys)
self.train()
if keep_tensor:
return att_ws
else:
return att_ws.cpu().numpy()
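A sketch of extracting attention weights from an (untrained) model built with the default arguments. The import path, the dimensions, and the random tensors are assumptions; the returned attentions are meaningless without trained weights and only illustrate the shapes:

```python
import argparse
import torch
# Assumed import path; see the construction sketch above.
from tools.espnet_minimal.nets.pytorch_backend.e2e_tts_tacotron2 import Tacotron2

parser = argparse.ArgumentParser()
Tacotron2.add_arguments(parser)
model = Tacotron2(idim=40, odim=80, args=parser.parse_args([]))

B, Tmax, Lmax = 2, 7, 12
xs = torch.randint(1, 40, (B, Tmax))   # dummy padded character ids
ilens = [7, 5]                         # plain ints; tensors/ndarrays are converted internally
ys = torch.randn(B, Lmax, 80)          # dummy padded target features

att_ws = model.calculate_all_attentions(xs, ilens, ys)
print(att_ws.shape)                    # (B, Lmax, Tmax) as a numpy array
```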
forward(self, xs, ilens, ys, labels, olens, spembs=None, extras=None, *args, **kwargs)
¶
Calculate forward propagation.
Parameters:

Name | Type | Description | Default
---|---|---|---
xs | Tensor | Batch of padded character ids (B, Tmax). | required
ilens | LongTensor | Batch of lengths of each input batch (B,). | required
ys | Tensor | Batch of padded target features (B, Lmax, odim). | required
olens | LongTensor | Batch of the lengths of each target (B,). | required
spembs | Tensor | Batch of speaker embedding vectors (B, spk_embed_dim). | None
extras | Tensor | Batch of groundtruth spectrograms (B, Lmax, spc_dim). | None

Returns:

Type | Description
---|---
Tensor | Loss value.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_tts_tacotron2.py
def forward(self, xs, ilens, ys, labels, olens, spembs=None, extras=None, *args, **kwargs):
"""Calculate forward propagation.
Args:
xs (Tensor): Batch of padded character ids (B, Tmax).
ilens (LongTensor): Batch of lengths of each input batch (B,).
ys (Tensor): Batch of padded target features (B, Lmax, odim).
olens (LongTensor): Batch of the lengths of each target (B,).
spembs (Tensor, optional): Batch of speaker embedding vectors (B, spk_embed_dim).
extras (Tensor, optional): Batch of groundtruth spectrograms (B, Lmax, spc_dim).
Returns:
Tensor: Loss value.
"""
# remove unnecessary padded part (for multi-gpus)
max_in = max(ilens)
max_out = max(olens)
if max_in != xs.shape[1]:
xs = xs[:, :max_in]
if max_out != ys.shape[1]:
ys = ys[:, :max_out]
labels = labels[:, :max_out]
# calculate tacotron2 outputs
hs, hlens = self.enc(xs, ilens)
if self.spk_embed_dim is not None:
spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1)
hs = torch.cat([hs, spembs], dim=-1)
after_outs, before_outs, logits, att_ws = self.dec(hs, hlens, ys)
# modify the mod part of the groundtruth
if self.reduction_factor > 1:
olens = olens.new([olen - olen % self.reduction_factor for olen in olens])
max_out = max(olens)
ys = ys[:, :max_out]
labels = labels[:, :max_out]
labels[:, -1] = 1.0 # make sure at least one frame has 1
# calculate taco2 loss
l1_loss, mse_loss, bce_loss = self.taco2_loss(
after_outs, before_outs, logits, ys, labels, olens)
loss = l1_loss + mse_loss + bce_loss
report_keys = [
{'l1_loss': l1_loss.item()},
{'mse_loss': mse_loss.item()},
{'bce_loss': bce_loss.item()},
]
# calculate attention loss
if self.use_guided_attn_loss:
# NOTE(kan-bayashi): length of output for auto-regressive input will be changed when r > 1
if self.reduction_factor > 1:
olens_in = olens.new([olen // self.reduction_factor for olen in olens])
else:
olens_in = olens
attn_loss = self.attn_loss(att_ws, ilens, olens_in)
loss = loss + attn_loss
report_keys += [
{'attn_loss': attn_loss.item()},
]
# calculate cbhg loss
if self.use_cbhg:
# remove unnecessary padded part (for multi-gpus)
if max_out != extras.shape[1]:
extras = extras[:, :max_out]
# calculate cbhg outputs & loss and report them
cbhg_outs, _ = self.cbhg(after_outs, olens)
cbhg_l1_loss, cbhg_mse_loss = self.cbhg_loss(cbhg_outs, extras, olens)
loss = loss + cbhg_l1_loss + cbhg_mse_loss
report_keys += [
{'cbhg_l1_loss': cbhg_l1_loss.item()},
{'cbhg_mse_loss': cbhg_mse_loss.item()},
]
report_keys += [{'loss': loss.item()}]
self.reporter.report(report_keys)
return loss
inference(self, x, inference_args, spemb=None, *args, **kwargs)
¶
Generate the sequence of features given the sequences of characters.
Parameters:

Name | Type | Description | Default
---|---|---|---
x | Tensor | Input sequence of characters (T,). | required
inference_args | Namespace | Decoding options: threshold (float), minlenratio (float), maxlenratio (float); see the docstring below. | required
spemb | Tensor | Speaker embedding vector (spk_embed_dim). | None

Returns:

Type | Description
---|---
Tensor | Output sequence of features (L, odim).
Tensor | Output sequence of stop probabilities (L,).
Tensor | Attention weights (L, T).
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_tts_tacotron2.py
def inference(self, x, inference_args, spemb=None, *args, **kwargs):
"""Generate the sequence of features given the sequences of characters.
Args:
x (Tensor): Input sequence of characters (T,).
inference_args (Namespace):
- threshold (float): Threshold in inference.
- minlenratio (float): Minimum length ratio in inference.
- maxlenratio (float): Maximum length ratio in inference.
spemb (Tensor, optional): Speaker embedding vector (spk_embed_dim).
Returns:
Tensor: Output sequence of features (L, odim).
Tensor: Output sequence of stop probabilities (L,).
Tensor: Attention weights (L, T).
"""
# get options
threshold = inference_args.threshold
minlenratio = inference_args.minlenratio
maxlenratio = inference_args.maxlenratio
use_att_constraint = getattr(inference_args, "use_att_constraint", False) # keep compatibility
backward_window = inference_args.backward_window if use_att_constraint else 0
forward_window = inference_args.forward_window if use_att_constraint else 0
# inference
h = self.enc.inference(x)
if self.spk_embed_dim is not None:
spemb = F.normalize(spemb, dim=0).unsqueeze(0).expand(h.size(0), -1)
h = torch.cat([h, spemb], dim=-1)
outs, probs, att_ws = self.dec.inference(h, threshold, minlenratio, maxlenratio,
use_att_constraint=use_att_constraint,
backward_window=backward_window,
forward_window=forward_window)
if self.use_cbhg:
cbhg_outs = self.cbhg.inference(outs)
return cbhg_outs, probs, att_ws
else:
return outs, probs, att_ws
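A decoding sketch with an untrained model, illustrating the call signature and the Namespace options only. The threshold/minlenratio/maxlenratio values, the import path, and the dummy character ids are assumptions, and the outputs carry no meaning without trained weights:

```python
import argparse
from argparse import Namespace
import torch
# Assumed import path; see the construction sketch above.
from tools.espnet_minimal.nets.pytorch_backend.e2e_tts_tacotron2 import Tacotron2

parser = argparse.ArgumentParser()
Tacotron2.add_arguments(parser)
model = Tacotron2(idim=40, odim=80, args=parser.parse_args([]))
model.eval()

inference_args = Namespace(threshold=0.5, minlenratio=0.0, maxlenratio=10.0)
x = torch.LongTensor([3, 7, 12, 5, 9])            # dummy character-id sequence (T,)

with torch.no_grad():
    outs, probs, att_ws = model.inference(x, inference_args)
print(outs.shape, probs.shape, att_ws.shape)      # (L, odim), (L,), (L, T)
```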
Tacotron2Loss (Module)
¶
Loss function module for Tacotron2.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_tts_tacotron2.py
class Tacotron2Loss(torch.nn.Module):
"""Loss function module for Tacotron2."""
def __init__(self, use_masking=True, use_weighted_masking=False, bce_pos_weight=20.0):
"""Initialize Tactoron2 loss module.
Args:
use_masking (bool): Whether to apply masking for padded part in loss calculation.
use_weighted_masking (bool): Whether to apply weighted masking in loss calculation.
bce_pos_weight (float): Weight of positive sample of stop token.
"""
super(Tacotron2Loss, self).__init__()
assert (use_masking != use_weighted_masking) or not use_masking
self.use_masking = use_masking
self.use_weighted_masking = use_weighted_masking
# define criterions
reduction = "none" if self.use_weighted_masking else "mean"
self.l1_criterion = torch.nn.L1Loss(reduction=reduction)
self.mse_criterion = torch.nn.MSELoss(reduction=reduction)
self.bce_criterion = torch.nn.BCEWithLogitsLoss(reduction=reduction,
pos_weight=torch.tensor(bce_pos_weight))
# NOTE(kan-bayashi): register pre hook function for the compatibility
self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook)
def forward(self, after_outs, before_outs, logits, ys, labels, olens):
"""Calculate forward propagation.
Args:
after_outs (Tensor): Batch of outputs after postnets (B, Lmax, odim).
before_outs (Tensor): Batch of outputs before postnets (B, Lmax, odim).
logits (Tensor): Batch of stop logits (B, Lmax).
ys (Tensor): Batch of padded target features (B, Lmax, odim).
labels (LongTensor): Batch of the sequences of stop token labels (B, Lmax).
olens (LongTensor): Batch of the lengths of each target (B,).
Returns:
Tensor: L1 loss value.
Tensor: Mean square error loss value.
Tensor: Binary cross entropy loss value.
"""
# make mask and apply it
if self.use_masking:
masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
ys = ys.masked_select(masks)
after_outs = after_outs.masked_select(masks)
before_outs = before_outs.masked_select(masks)
labels = labels.masked_select(masks[:, :, 0])
logits = logits.masked_select(masks[:, :, 0])
# calculate loss
l1_loss = self.l1_criterion(after_outs, ys) + self.l1_criterion(before_outs, ys)
mse_loss = self.mse_criterion(after_outs, ys) + self.mse_criterion(before_outs, ys)
bce_loss = self.bce_criterion(logits, labels)
# make weighted mask and apply it
if self.use_weighted_masking:
masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
weights = masks.float() / masks.sum(dim=1, keepdim=True).float()
out_weights = weights.div(ys.size(0) * ys.size(2))
logit_weights = weights.div(ys.size(0))
# apply weight
l1_loss = l1_loss.mul(out_weights).masked_select(masks).sum()
mse_loss = mse_loss.mul(out_weights).masked_select(masks).sum()
bce_loss = bce_loss.mul(logit_weights.squeeze(-1)).masked_select(masks.squeeze(-1)).sum()
return l1_loss, mse_loss, bce_loss
def _load_state_dict_pre_hook(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
"""Apply pre hook fucntion before loading state dict.
From v.0.6.1 `bce_criterion.pos_weight` param is registered as a parameter but
old models do not include it and as a result, it causes missing key error when
loading old model parameter. This function solve the issue by adding param in
state dict before loading as a pre hook function of the `load_state_dict` method.
"""
key = prefix + "bce_criterion.pos_weight"
if key not in state_dict:
state_dict[key] = self.bce_criterion.pos_weight
__init__(self, use_masking=True, use_weighted_masking=False, bce_pos_weight=20.0)
special
¶
Initialize Tacotron2 loss module.
Parameters:

Name | Type | Description | Default
---|---|---|---
use_masking | bool | Whether to apply masking for padded part in loss calculation. | True
use_weighted_masking | bool | Whether to apply weighted masking in loss calculation. | False
bce_pos_weight | float | Weight of positive sample of stop token. | 20.0
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_tts_tacotron2.py
def __init__(self, use_masking=True, use_weighted_masking=False, bce_pos_weight=20.0):
"""Initialize Tactoron2 loss module.
Args:
use_masking (bool): Whether to apply masking for padded part in loss calculation.
use_weighted_masking (bool): Whether to apply weighted masking in loss calculation.
bce_pos_weight (float): Weight of positive sample of stop token.
"""
super(Tacotron2Loss, self).__init__()
assert (use_masking != use_weighted_masking) or not use_masking
self.use_masking = use_masking
self.use_weighted_masking = use_weighted_masking
# define criterions
reduction = "none" if self.use_weighted_masking else "mean"
self.l1_criterion = torch.nn.L1Loss(reduction=reduction)
self.mse_criterion = torch.nn.MSELoss(reduction=reduction)
self.bce_criterion = torch.nn.BCEWithLogitsLoss(reduction=reduction,
pos_weight=torch.tensor(bce_pos_weight))
# NOTE(kan-bayashi): register pre hook function for the compatibility
self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook)
forward(self, after_outs, before_outs, logits, ys, labels, olens)
¶
Calculate forward propagation.
Parameters:

Name | Type | Description | Default
---|---|---|---
after_outs | Tensor | Batch of outputs after postnets (B, Lmax, odim). | required
before_outs | Tensor | Batch of outputs before postnets (B, Lmax, odim). | required
logits | Tensor | Batch of stop logits (B, Lmax). | required
ys | Tensor | Batch of padded target features (B, Lmax, odim). | required
labels | LongTensor | Batch of the sequences of stop token labels (B, Lmax). | required
olens | LongTensor | Batch of the lengths of each target (B,). | required

Returns:

Type | Description
---|---
Tensor | L1 loss value.
Tensor | Mean square error loss value.
Tensor | Binary cross entropy loss value.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_tts_tacotron2.py
def forward(self, after_outs, before_outs, logits, ys, labels, olens):
"""Calculate forward propagation.
Args:
after_outs (Tensor): Batch of outputs after postnets (B, Lmax, odim).
before_outs (Tensor): Batch of outputs before postnets (B, Lmax, odim).
logits (Tensor): Batch of stop logits (B, Lmax).
ys (Tensor): Batch of padded target features (B, Lmax, odim).
labels (LongTensor): Batch of the sequences of stop token labels (B, Lmax).
olens (LongTensor): Batch of the lengths of each target (B,).
Returns:
Tensor: L1 loss value.
Tensor: Mean square error loss value.
Tensor: Binary cross entropy loss value.
"""
# make mask and apply it
if self.use_masking:
masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
ys = ys.masked_select(masks)
after_outs = after_outs.masked_select(masks)
before_outs = before_outs.masked_select(masks)
labels = labels.masked_select(masks[:, :, 0])
logits = logits.masked_select(masks[:, :, 0])
# calculate loss
l1_loss = self.l1_criterion(after_outs, ys) + self.l1_criterion(before_outs, ys)
mse_loss = self.mse_criterion(after_outs, ys) + self.mse_criterion(before_outs, ys)
bce_loss = self.bce_criterion(logits, labels)
# make weighted mask and apply it
if self.use_weighted_masking:
masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
weights = masks.float() / masks.sum(dim=1, keepdim=True).float()
out_weights = weights.div(ys.size(0) * ys.size(2))
logit_weights = weights.div(ys.size(0))
# apply weight
l1_loss = l1_loss.mul(out_weights).masked_select(masks).sum()
mse_loss = mse_loss.mul(out_weights).masked_select(masks).sum()
bce_loss = bce_loss.mul(logit_weights.squeeze(-1)).masked_select(masks.squeeze(-1)).sum()
return l1_loss, mse_loss, bce_loss
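A standalone sketch of the loss computation on random tensors, illustrating the expected shapes and the three returned terms. The import path is an assumption; with use_masking=True, padded frames beyond each olen are excluded from all three terms:

```python
import torch
# Assumed import path, derived from the source file location shown above.
from tools.espnet_minimal.nets.pytorch_backend.e2e_tts_tacotron2 import Tacotron2Loss

criterion = Tacotron2Loss(use_masking=True, use_weighted_masking=False, bce_pos_weight=20.0)

B, Lmax, odim = 2, 6, 80
after_outs = torch.randn(B, Lmax, odim)   # decoder outputs after the postnet
before_outs = torch.randn(B, Lmax, odim)  # decoder outputs before the postnet
logits = torch.randn(B, Lmax)             # stop-token logits
ys = torch.randn(B, Lmax, odim)           # target features
labels = torch.zeros(B, Lmax)
labels[0, 5] = 1.0                        # stop label on the last valid frame of each utterance
labels[1, 3] = 1.0
olens = torch.LongTensor([6, 4])

l1_loss, mse_loss, bce_loss = criterion(after_outs, before_outs, logits, ys, labels, olens)
print(l1_loss.item(), mse_loss.item(), bce_loss.item())
```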
e2e_tts_transformer
¶
TTS-Transformer related modules.
GuidedMultiHeadAttentionLoss (GuidedAttentionLoss)
¶
Guided attention loss function module for multi head attention.
Parameters:

Name | Type | Description | Default
---|---|---|---
sigma | float | Standard deviation controlling how close the attention must stay to the diagonal. | required
alpha | float | Scaling coefficient (lambda). | required
reset_always | bool | Whether to always reset masks. | required
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_tts_transformer.py
class GuidedMultiHeadAttentionLoss(GuidedAttentionLoss):
"""Guided attention loss function module for multi head attention.
Args:
sigma (float, optional): Standard deviation to control how close attention to a diagonal.
alpha (float, optional): Scaling coefficient (lambda).
reset_always (bool, optional): Whether to always reset masks.
"""
def forward(self, att_ws, ilens, olens):
"""Calculate forward propagation.
Args:
att_ws (Tensor): Batch of multi head attention weights (B, H, T_max_out, T_max_in).
ilens (LongTensor): Batch of input lengths (B,).
olens (LongTensor): Batch of output lengths (B,).
Returns:
Tensor: Guided attention loss value.
"""
if self.guided_attn_masks is None:
self.guided_attn_masks = self._make_guided_attention_masks(ilens, olens).to(att_ws.device).unsqueeze(1)
if self.masks is None:
self.masks = self._make_masks(ilens, olens).to(att_ws.device).unsqueeze(1)
losses = self.guided_attn_masks * att_ws
loss = torch.mean(losses.masked_select(self.masks))
if self.reset_always:
self._reset_masks()
return self.alpha * loss
forward(self, att_ws, ilens, olens)
¶
Calculate forward propagation.
Parameters:

Name | Type | Description | Default
---|---|---|---
att_ws | Tensor | Batch of multi head attention weights (B, H, T_max_out, T_max_in). | required
ilens | LongTensor | Batch of input lengths (B,). | required
olens | LongTensor | Batch of output lengths (B,). | required

Returns:

Type | Description
---|---
Tensor | Guided attention loss value.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_tts_transformer.py
def forward(self, att_ws, ilens, olens):
"""Calculate forward propagation.
Args:
att_ws (Tensor): Batch of multi head attention weights (B, H, T_max_out, T_max_in).
ilens (LongTensor): Batch of input lenghts (B,).
olens (LongTensor): Batch of output lenghts (B,).
Returns:
Tensor: Guided attention loss value.
"""
if self.guided_attn_masks is None:
self.guided_attn_masks = self._make_guided_attention_masks(ilens, olens).to(att_ws.device).unsqueeze(1)
if self.masks is None:
self.masks = self._make_masks(ilens, olens).to(att_ws.device).unsqueeze(1)
losses = self.guided_attn_masks * att_ws
loss = torch.mean(losses.masked_select(self.masks))
if self.reset_always:
self._reset_masks()
return self.alpha * loss
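Usage mirrors the single-head parent class, with an extra head dimension in the attention tensor; the same diagonal mask is broadcast over all heads. A sketch with dummy weights (import path assumed from the source location above):

```python
import torch
# Assumed import path, derived from the source file location shown above.
from tools.espnet_minimal.nets.pytorch_backend.e2e_tts_transformer import GuidedMultiHeadAttentionLoss

criterion = GuidedMultiHeadAttentionLoss(sigma=0.4, alpha=1.0)

B, H, T_out, T_in = 2, 4, 8, 5
att_ws = torch.softmax(torch.randn(B, H, T_out, T_in), dim=-1)  # dummy multi-head attention
ilens = torch.LongTensor([5, 2])
olens = torch.LongTensor([8, 5])

loss = criterion(att_ws, ilens, olens)   # averaged over all heads and batch entries
print(loss.item())
```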
TTSPlot (PlotAttentionReport)
¶
Attention plot module for TTS-Transformer.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_tts_transformer.py
class TTSPlot(PlotAttentionReport):
"""Attention plot module for TTS-Transformer."""
def plotfn(self, data, attn_dict, outdir, suffix="png", savefn=None):
"""Plot multi head attentions.
Args:
data (dict): Utts info from json file.
attn_dict (dict): Multi head attention dict.
Values should be numpy.ndarray (H, L, T)
outdir (str): Directory name to save figures.
suffix (str): Filename suffix including image type (e.g., png).
savefn (function): Function to save figures.
"""
import matplotlib.pyplot as plt
for name, att_ws in attn_dict.items():
for idx, att_w in enumerate(att_ws):
filename = "%s/%s.%s.%s" % (
outdir, data[idx][0], name, suffix)
if "fbank" in name:
fig = plt.Figure()
ax = fig.subplots(1, 1)
ax.imshow(att_w, aspect="auto")
ax.set_xlabel("frames")
ax.set_ylabel("fbank coeff")
fig.tight_layout()
else:
fig = _plot_and_save_attention(att_w, filename)
savefn(fig, filename)
plotfn(self, data, attn_dict, outdir, suffix='png', savefn=None)
¶
Plot multi head attentions.
Parameters:

Name | Type | Description | Default
---|---|---|---
data | dict | Utts info from json file. | required
attn_dict | dict | Multi head attention dict. Values should be numpy.ndarray (H, L, T). | required
outdir | str | Directory name to save figures. | required
suffix | str | Filename suffix including image type (e.g., png). | 'png'
savefn | function | Function to save figures. | None
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_tts_transformer.py
def plotfn(self, data, attn_dict, outdir, suffix="png", savefn=None):
"""Plot multi head attentions.
Args:
data (dict): Utts info from json file.
attn_dict (dict): Multi head attention dict.
Values should be numpy.ndarray (H, L, T)
outdir (str): Directory name to save figures.
suffix (str): Filename suffix including image type (e.g., png).
savefn (function): Function to save figures.
"""
import matplotlib.pyplot as plt
for name, att_ws in attn_dict.items():
for idx, att_w in enumerate(att_ws):
filename = "%s/%s.%s.%s" % (
outdir, data[idx][0], name, suffix)
if "fbank" in name:
fig = plt.Figure()
ax = fig.subplots(1, 1)
ax.imshow(att_w, aspect="auto")
ax.set_xlabel("frames")
ax.set_ylabel("fbank coeff")
fig.tight_layout()
else:
fig = _plot_and_save_attention(att_w, filename)
savefn(fig, filename)
Transformer (TTSInterface, Module)
¶
Text-to-Speech Transformer module.
This is the text-to-speech Transformer described in
"Neural Speech Synthesis with Transformer Network" (https://arxiv.org/pdf/1809.08895.pdf),
which converts a sequence of characters or phonemes into a sequence of Mel-filterbanks.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_tts_transformer.py
class Transformer(TTSInterface, torch.nn.Module):
"""Text-to-Speech Transformer module.
This is a module of text-to-speech Transformer described in `Neural Speech Synthesis with Transformer Network`_,
which converts the sequence of characters or phonemes into the sequence of Mel-filterbanks.
.. _`Neural Speech Synthesis with Transformer Network`:
https://arxiv.org/pdf/1809.08895.pdf
"""
@staticmethod
def add_arguments(parser):
"""Add model-specific arguments to the parser."""
group = parser.add_argument_group("transformer model setting")
# network structure related
group.add_argument("--embed-dim", default=512, type=int,
help="Dimension of character embedding in encoder prenet")
group.add_argument("--eprenet-conv-layers", default=3, type=int,
help="Number of encoder prenet convolution layers")
group.add_argument("--eprenet-conv-chans", default=256, type=int,
help="Number of encoder prenet convolution channels")
group.add_argument("--eprenet-conv-filts", default=5, type=int,
help="Filter size of encoder prenet convolution")
group.add_argument("--dprenet-layers", default=2, type=int,
help="Number of decoder prenet layers")
group.add_argument("--dprenet-units", default=256, type=int,
help="Number of decoder prenet hidden units")
group.add_argument("--elayers", default=3, type=int,
help="Number of encoder layers")
group.add_argument("--eunits", default=1536, type=int,
help="Number of encoder hidden units")
group.add_argument("--adim", default=384, type=int,
help="Number of attention transformation dimensions")
group.add_argument("--aheads", default=4, type=int,
help="Number of heads for multi head attention")
group.add_argument("--dlayers", default=3, type=int,
help="Number of decoder layers")
group.add_argument("--dunits", default=1536, type=int,
help="Number of decoder hidden units")
group.add_argument("--positionwise-layer-type", default="linear", type=str,
choices=["linear", "conv1d", "conv1d-linear"],
help="Positionwise layer type.")
group.add_argument("--positionwise-conv-kernel-size", default=1, type=int,
help="Kernel size of positionwise conv1d layer")
group.add_argument("--postnet-layers", default=5, type=int,
help="Number of postnet layers")
group.add_argument("--postnet-chans", default=256, type=int,
help="Number of postnet channels")
group.add_argument("--postnet-filts", default=5, type=int,
help="Filter size of postnet")
group.add_argument("--use-scaled-pos-enc", default=True, type=strtobool,
help="Use trainable scaled positional encoding instead of the fixed scale one.")
group.add_argument("--use-batch-norm", default=True, type=strtobool,
help="Whether to use batch normalization")
group.add_argument("--encoder-normalize-before", default=False, type=strtobool,
help="Whether to apply layer norm before encoder block")
group.add_argument("--decoder-normalize-before", default=False, type=strtobool,
help="Whether to apply layer norm before decoder block")
group.add_argument("--encoder-concat-after", default=False, type=strtobool,
help="Whether to concatenate attention layer's input and output in encoder")
group.add_argument("--decoder-concat-after", default=False, type=strtobool,
help="Whether to concatenate attention layer's input and output in decoder")
group.add_argument("--reduction-factor", default=1, type=int,
help="Reduction factor")
group.add_argument("--spk-embed-dim", default=None, type=int,
help="Number of speaker embedding dimensions")
group.add_argument("--spk-embed-integration-type", type=str, default="add",
choices=["add", "concat"],
help="How to integrate speaker embedding")
# training related
group.add_argument("--transformer-init", type=str, default="pytorch",
choices=["pytorch", "xavier_uniform", "xavier_normal",
"kaiming_uniform", "kaiming_normal"],
help="How to initialize transformer parameters")
group.add_argument("--initial-encoder-alpha", type=float, default=1.0,
help="Initial alpha value in encoder's ScaledPositionalEncoding")
group.add_argument("--initial-decoder-alpha", type=float, default=1.0,
help="Initial alpha value in decoder's ScaledPositionalEncoding")
group.add_argument("--transformer-lr", default=1.0, type=float,
help="Initial value of learning rate")
group.add_argument("--transformer-warmup-steps", default=4000, type=int,
help="Optimizer warmup steps")
group.add_argument("--transformer-enc-dropout-rate", default=0.1, type=float,
help="Dropout rate for transformer encoder except for attention")
group.add_argument("--transformer-enc-positional-dropout-rate", default=0.1, type=float,
help="Dropout rate for transformer encoder positional encoding")
group.add_argument("--transformer-enc-attn-dropout-rate", default=0.1, type=float,
help="Dropout rate for transformer encoder self-attention")
group.add_argument("--transformer-dec-dropout-rate", default=0.1, type=float,
help="Dropout rate for transformer decoder except for attention and pos encoding")
group.add_argument("--transformer-dec-positional-dropout-rate", default=0.1, type=float,
help="Dropout rate for transformer decoder positional encoding")
group.add_argument("--transformer-dec-attn-dropout-rate", default=0.1, type=float,
help="Dropout rate for transformer decoder self-attention")
group.add_argument("--transformer-enc-dec-attn-dropout-rate", default=0.1, type=float,
help="Dropout rate for transformer encoder-decoder attention")
group.add_argument("--eprenet-dropout-rate", default=0.5, type=float,
help="Dropout rate in encoder prenet")
group.add_argument("--dprenet-dropout-rate", default=0.5, type=float,
help="Dropout rate in decoder prenet")
group.add_argument("--postnet-dropout-rate", default=0.5, type=float,
help="Dropout rate in postnet")
group.add_argument("--pretrained-model", default=None, type=str,
help="Pretrained model path")
# loss related
group.add_argument("--use-masking", default=True, type=strtobool,
help="Whether to use masking in calculation of loss")
group.add_argument("--use-weighted-masking", default=False, type=strtobool,
help="Whether to use weighted masking in calculation of loss")
group.add_argument("--loss-type", default="L1", choices=["L1", "L2", "L1+L2"],
help="How to calc loss")
group.add_argument("--bce-pos-weight", default=5.0, type=float,
help="Positive sample weight in BCE calculation (only for use-masking=True)")
group.add_argument("--use-guided-attn-loss", default=False, type=strtobool,
help="Whether to use guided attention loss")
group.add_argument("--guided-attn-loss-sigma", default=0.4, type=float,
help="Sigma in guided attention loss")
group.add_argument("--guided-attn-loss-lambda", default=1.0, type=float,
help="Lambda in guided attention loss")
group.add_argument("--num-heads-applied-guided-attn", default=2, type=int,
help="Number of heads in each layer to be applied guided attention loss"
"if set -1, all of the heads will be applied.")
group.add_argument("--num-layers-applied-guided-attn", default=2, type=int,
help="Number of layers to be applied guided attention loss"
"if set -1, all of the layers will be applied.")
group.add_argument("--modules-applied-guided-attn", type=str, nargs="+",
default=["encoder-decoder"],
help="Module name list to be applied guided attention loss")
return parser
@property
def attention_plot_class(self):
"""Return plot class for attention weight plot."""
return TTSPlot
def __init__(self, idim, odim, args=None):
"""Initialize TTS-Transformer module.
Args:
idim (int): Dimension of the inputs.
odim (int): Dimension of the outputs.
args (Namespace, optional):
- embed_dim (int): Dimension of character embedding.
- eprenet_conv_layers (int): Number of encoder prenet convolution layers.
- eprenet_conv_chans (int): Number of encoder prenet convolution channels.
- eprenet_conv_filts (int): Filter size of encoder prenet convolution.
- dprenet_layers (int): Number of decoder prenet layers.
- dprenet_units (int): Number of decoder prenet hidden units.
- elayers (int): Number of encoder layers.
- eunits (int): Number of encoder hidden units.
- adim (int): Number of attention transformation dimensions.
- aheads (int): Number of heads for multi head attention.
- dlayers (int): Number of decoder layers.
- dunits (int): Number of decoder hidden units.
- postnet_layers (int): Number of postnet layers.
- postnet_chans (int): Number of postnet channels.
- postnet_filts (int): Filter size of postnet.
- use_scaled_pos_enc (bool): Whether to use trainable scaled positional encoding.
- use_batch_norm (bool): Whether to use batch normalization in encoder prenet.
- encoder_normalize_before (bool): Whether to perform layer normalization before encoder block.
- decoder_normalize_before (bool): Whether to perform layer normalization before decoder block.
- encoder_concat_after (bool): Whether to concatenate attention layer's input and output in encoder.
- decoder_concat_after (bool): Whether to concatenate attention layer's input and output in decoder.
- reduction_factor (int): Reduction factor.
- spk_embed_dim (int): Number of speaker embedding dimensions.
- spk_embed_integration_type: How to integrate speaker embedding.
- transformer_init (float): How to initialize transformer parameters.
- transformer_lr (float): Initial value of learning rate.
- transformer_warmup_steps (int): Optimizer warmup steps.
- transformer_enc_dropout_rate (float): Dropout rate in encoder except attention & positional encoding.
- transformer_enc_positional_dropout_rate (float): Dropout rate after encoder positional encoding.
- transformer_enc_attn_dropout_rate (float): Dropout rate in encoder self-attention module.
- transformer_dec_dropout_rate (float): Dropout rate in decoder except attention & positional encoding.
- transformer_dec_positional_dropout_rate (float): Dropout rate after decoder positional encoding.
- transformer_dec_attn_dropout_rate (float): Dropout rate in decoder self-attention module.
- transformer_enc_dec_attn_dropout_rate (float): Dropout rate in encoder-decoder attention module.
- eprenet_dropout_rate (float): Dropout rate in encoder prenet.
- dprenet_dropout_rate (float): Dropout rate in decoder prenet.
- postnet_dropout_rate (float): Dropout rate in postnet.
- use_masking (bool): Whether to apply masking for padded part in loss calculation.
- use_weighted_masking (bool): Whether to apply weighted masking in loss calculation.
- bce_pos_weight (float): Positive sample weight in bce calculation (only for use_masking=true).
- loss_type (str): How to calculate loss.
- use_guided_attn_loss (bool): Whether to use guided attention loss.
- num_heads_applied_guided_attn (int): Number of heads in each layer to apply guided attention loss.
- num_layers_applied_guided_attn (int): Number of layers to apply guided attention loss.
- modules_applied_guided_attn (list): List of module names to apply guided attention loss.
- guided_attn_loss_sigma (float): Sigma in guided attention loss.
- guided_attn_loss_lambda (float): Lambda in guided attention loss.
"""
# initialize base classes
TTSInterface.__init__(self)
torch.nn.Module.__init__(self)
# fill missing arguments
args = fill_missing_args(args, self.add_arguments)
# store hyperparameters
self.idim = idim
self.odim = odim
self.spk_embed_dim = args.spk_embed_dim
if self.spk_embed_dim is not None:
self.spk_embed_integration_type = args.spk_embed_integration_type
self.use_scaled_pos_enc = args.use_scaled_pos_enc
self.reduction_factor = args.reduction_factor
self.loss_type = args.loss_type
self.use_guided_attn_loss = args.use_guided_attn_loss
if self.use_guided_attn_loss:
if args.num_layers_applied_guided_attn == -1:
self.num_layers_applied_guided_attn = args.elayers
else:
self.num_layers_applied_guided_attn = args.num_layers_applied_guided_attn
if args.num_heads_applied_guided_attn == -1:
self.num_heads_applied_guided_attn = args.aheads
else:
self.num_heads_applied_guided_attn = args.num_heads_applied_guided_attn
self.modules_applied_guided_attn = args.modules_applied_guided_attn
# use idx 0 as padding idx
padding_idx = 0
# get positional encoding class
pos_enc_class = ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding
# define transformer encoder
if args.eprenet_conv_layers != 0:
# encoder prenet
encoder_input_layer = torch.nn.Sequential(
EncoderPrenet(
idim=idim,
embed_dim=args.embed_dim,
elayers=0,
econv_layers=args.eprenet_conv_layers,
econv_chans=args.eprenet_conv_chans,
econv_filts=args.eprenet_conv_filts,
use_batch_norm=args.use_batch_norm,
dropout_rate=args.eprenet_dropout_rate,
padding_idx=padding_idx
),
torch.nn.Linear(args.eprenet_conv_chans, args.adim)
)
else:
encoder_input_layer = torch.nn.Embedding(
num_embeddings=idim,
embedding_dim=args.adim,
padding_idx=padding_idx
)
self.encoder = Encoder(
idim=idim,
attention_dim=args.adim,
attention_heads=args.aheads,
linear_units=args.eunits,
num_blocks=args.elayers,
input_layer=encoder_input_layer,
dropout_rate=args.transformer_enc_dropout_rate,
positional_dropout_rate=args.transformer_enc_positional_dropout_rate,
attention_dropout_rate=args.transformer_enc_attn_dropout_rate,
pos_enc_class=pos_enc_class,
normalize_before=args.encoder_normalize_before,
concat_after=args.encoder_concat_after,
positionwise_layer_type=args.positionwise_layer_type,
positionwise_conv_kernel_size=args.positionwise_conv_kernel_size,
)
# define projection layer
if self.spk_embed_dim is not None:
if self.spk_embed_integration_type == "add":
self.projection = torch.nn.Linear(self.spk_embed_dim, args.adim)
else:
self.projection = torch.nn.Linear(args.adim + self.spk_embed_dim, args.adim)
# define transformer decoder
if args.dprenet_layers != 0:
# decoder prenet
decoder_input_layer = torch.nn.Sequential(
DecoderPrenet(
idim=odim,
n_layers=args.dprenet_layers,
n_units=args.dprenet_units,
dropout_rate=args.dprenet_dropout_rate
),
torch.nn.Linear(args.dprenet_units, args.adim)
)
else:
decoder_input_layer = "linear"
self.decoder = Decoder(
odim=-1,
attention_dim=args.adim,
attention_heads=args.aheads,
linear_units=args.dunits,
num_blocks=args.dlayers,
dropout_rate=args.transformer_dec_dropout_rate,
positional_dropout_rate=args.transformer_dec_positional_dropout_rate,
self_attention_dropout_rate=args.transformer_dec_attn_dropout_rate,
src_attention_dropout_rate=args.transformer_enc_dec_attn_dropout_rate,
input_layer=decoder_input_layer,
use_output_layer=False,
pos_enc_class=pos_enc_class,
normalize_before=args.decoder_normalize_before,
concat_after=args.decoder_concat_after
)
# define final projection
self.feat_out = torch.nn.Linear(args.adim, odim * args.reduction_factor)
self.prob_out = torch.nn.Linear(args.adim, args.reduction_factor)
# define postnet
self.postnet = None if args.postnet_layers == 0 else Postnet(
idim=idim,
odim=odim,
n_layers=args.postnet_layers,
n_chans=args.postnet_chans,
n_filts=args.postnet_filts,
use_batch_norm=args.use_batch_norm,
dropout_rate=args.postnet_dropout_rate
)
# define loss function
self.criterion = TransformerLoss(use_masking=args.use_masking,
use_weighted_masking=args.use_weighted_masking,
bce_pos_weight=args.bce_pos_weight)
if self.use_guided_attn_loss:
self.attn_criterion = GuidedMultiHeadAttentionLoss(
sigma=args.guided_attn_loss_sigma,
alpha=args.guided_attn_loss_lambda,
)
# initialize parameters
self._reset_parameters(init_type=args.transformer_init,
init_enc_alpha=args.initial_encoder_alpha,
init_dec_alpha=args.initial_decoder_alpha)
# load pretrained model
if args.pretrained_model is not None:
self.load_pretrained_model(args.pretrained_model)
def _reset_parameters(self, init_type, init_enc_alpha=1.0, init_dec_alpha=1.0):
# initialize parameters
initialize(self, init_type)
# initialize alpha in scaled positional encoding
if self.use_scaled_pos_enc:
self.encoder.embed[-1].alpha.data = torch.tensor(init_enc_alpha)
self.decoder.embed[-1].alpha.data = torch.tensor(init_dec_alpha)
def _add_first_frame_and_remove_last_frame(self, ys):
ys_in = torch.cat([ys.new_zeros((ys.shape[0], 1, ys.shape[2])), ys[:, :-1]], dim=1)
return ys_in
def forward(self, xs, ilens, ys, labels, olens, spembs=None, *args, **kwargs):
"""Calculate forward propagation.
Args:
xs (Tensor): Batch of padded character ids (B, Tmax).
ilens (LongTensor): Batch of lengths of each input batch (B,).
ys (Tensor): Batch of padded target features (B, Lmax, odim).
olens (LongTensor): Batch of the lengths of each target (B,).
spembs (Tensor, optional): Batch of speaker embedding vectors (B, spk_embed_dim).
Returns:
Tensor: Loss value.
"""
# remove unnecessary padded part (for multi-gpus)
max_ilen = max(ilens)
max_olen = max(olens)
if max_ilen != xs.shape[1]:
xs = xs[:, :max_ilen]
if max_olen != ys.shape[1]:
ys = ys[:, :max_olen]
labels = labels[:, :max_olen]
# forward encoder
x_masks = self._source_mask(ilens)
hs, _ = self.encoder(xs, x_masks)
# integrate speaker embedding
if self.spk_embed_dim is not None:
hs = self._integrate_with_spk_embed(hs, spembs)
# thin out frames for reduction factor (B, Lmax, odim) -> (B, Lmax//r, odim)
if self.reduction_factor > 1:
ys_in = ys[:, self.reduction_factor - 1::self.reduction_factor]
olens_in = olens.new([olen // self.reduction_factor for olen in olens])
else:
ys_in, olens_in = ys, olens
# add first zero frame and remove last frame for auto-regressive
ys_in = self._add_first_frame_and_remove_last_frame(ys_in)
# forward decoder
y_masks = self._target_mask(olens_in)
xy_masks = self._source_to_target_mask(ilens, olens_in)
zs, _ = self.decoder(ys_in, y_masks, hs, xy_masks)
# (B, Lmax//r, odim * r) -> (B, Lmax//r * r, odim)
before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim)
# (B, Lmax//r, r) -> (B, Lmax//r * r)
logits = self.prob_out(zs).view(zs.size(0), -1)
# postnet -> (B, Lmax//r * r, odim)
if self.postnet is None:
after_outs = before_outs
else:
after_outs = before_outs + self.postnet(before_outs.transpose(1, 2)).transpose(1, 2)
# modify mod part of groundtruth
if self.reduction_factor > 1:
olens = olens.new([olen - olen % self.reduction_factor for olen in olens])
max_olen = max(olens)
ys = ys[:, :max_olen]
labels = labels[:, :max_olen]
labels[:, -1] = 1.0 # make sure at least one frame has 1
# calculate loss values
l1_loss, l2_loss, bce_loss = self.criterion(
after_outs, before_outs, logits, ys, labels, olens)
if self.loss_type == "L1":
loss = l1_loss + bce_loss
elif self.loss_type == "L2":
loss = l2_loss + bce_loss
elif self.loss_type == "L1+L2":
loss = l1_loss + l2_loss + bce_loss
else:
raise ValueError("unknown --loss-type " + self.loss_type)
report_keys = [
{"l1_loss": l1_loss.item()},
{"l2_loss": l2_loss.item()},
{"bce_loss": bce_loss.item()},
{"loss": loss.item()},
]
# calculate guided attention loss
if self.use_guided_attn_loss:
# calculate for encoder
if "encoder" in self.modules_applied_guided_attn:
att_ws = []
for idx, layer_idx in enumerate(reversed(range(len(self.encoder.encoders)))):
att_ws += [self.encoder.encoders[layer_idx].self_attn.attn[:, :self.num_heads_applied_guided_attn]]
if idx + 1 == self.num_layers_applied_guided_attn:
break
att_ws = torch.cat(att_ws, dim=1) # (B, H*L, T_in, T_in)
enc_attn_loss = self.attn_criterion(att_ws, ilens, ilens)
loss = loss + enc_attn_loss
report_keys += [{"enc_attn_loss": enc_attn_loss.item()}]
# calculate for decoder
if "decoder" in self.modules_applied_guided_attn:
att_ws = []
for idx, layer_idx in enumerate(reversed(range(len(self.decoder.decoders)))):
att_ws += [self.decoder.decoders[layer_idx].self_attn.attn[:, :self.num_heads_applied_guided_attn]]
if idx + 1 == self.num_layers_applied_guided_attn:
break
att_ws = torch.cat(att_ws, dim=1) # (B, H*L, T_out, T_out)
dec_attn_loss = self.attn_criterion(att_ws, olens_in, olens_in)
loss = loss + dec_attn_loss
report_keys += [{"dec_attn_loss": dec_attn_loss.item()}]
# calculate for encoder-decoder
if "encoder-decoder" in self.modules_applied_guided_attn:
att_ws = []
for idx, layer_idx in enumerate(reversed(range(len(self.decoder.decoders)))):
att_ws += [self.decoder.decoders[layer_idx].src_attn.attn[:, :self.num_heads_applied_guided_attn]]
if idx + 1 == self.num_layers_applied_guided_attn:
break
att_ws = torch.cat(att_ws, dim=1) # (B, H*L, T_out, T_in)
enc_dec_attn_loss = self.attn_criterion(att_ws, ilens, olens_in)
loss = loss + enc_dec_attn_loss
report_keys += [{"enc_dec_attn_loss": enc_dec_attn_loss.item()}]
# report extra information
if self.use_scaled_pos_enc:
report_keys += [
{"encoder_alpha": self.encoder.embed[-1].alpha.data.item()},
{"decoder_alpha": self.decoder.embed[-1].alpha.data.item()},
]
self.reporter.report(report_keys)
return loss
def inference(self, x, inference_args, spemb=None, *args, **kwargs):
"""Generate the sequence of features given the sequences of characters.
Args:
x (Tensor): Input sequence of characters (T,).
inference_args (Namespace):
- threshold (float): Threshold in inference.
- minlenratio (float): Minimum length ratio in inference.
- maxlenratio (float): Maximum length ratio in inference.
spemb (Tensor, optional): Speaker embedding vector (spk_embed_dim).
Returns:
Tensor: Output sequence of features (L, odim).
Tensor: Output sequence of stop probabilities (L,).
Tensor: Encoder-decoder (source) attention weights (#layers, #heads, L, T).
"""
# get options
threshold = inference_args.threshold
minlenratio = inference_args.minlenratio
maxlenratio = inference_args.maxlenratio
use_att_constraint = getattr(inference_args, "use_att_constraint", False) # keep compatibility
if use_att_constraint:
logging.warning("Attention constraint is not yet supported in Transformer. Not enabled.")
# forward encoder
xs = x.unsqueeze(0)
hs, _ = self.encoder(xs, None)
# integrate speaker embedding
if self.spk_embed_dim is not None:
spembs = spemb.unsqueeze(0)
hs = self._integrate_with_spk_embed(hs, spembs)
# set limits of length
maxlen = int(hs.size(1) * maxlenratio / self.reduction_factor)
minlen = int(hs.size(1) * minlenratio / self.reduction_factor)
# initialize
idx = 0
ys = hs.new_zeros(1, 1, self.odim)
outs, probs = [], []
# forward decoder step-by-step
z_cache = self.decoder.init_state(x)
while True:
# update index
idx += 1
# calculate output and stop prob at idx-th step
y_masks = subsequent_mask(idx).unsqueeze(0).to(x.device)
z, z_cache = self.decoder.forward_one_step(ys, y_masks, hs, cache=z_cache) # (B, adim)
outs += [self.feat_out(z).view(self.reduction_factor, self.odim)] # [(r, odim), ...]
probs += [torch.sigmoid(self.prob_out(z))[0]] # [(r), ...]
# update next inputs
ys = torch.cat((ys, outs[-1][-1].view(1, 1, self.odim)), dim=1) # (1, idx + 1, odim)
# get attention weights
att_ws_ = []
for name, m in self.named_modules():
if isinstance(m, MultiHeadedAttention) and "src" in name:
att_ws_ += [m.attn[0, :, -1].unsqueeze(1)] # [(#heads, 1, T),...]
if idx == 1:
att_ws = att_ws_
else:
# [(#heads, l, T), ...]
att_ws = [torch.cat([att_w, att_w_], dim=1) for att_w, att_w_ in zip(att_ws, att_ws_)]
# check whether to finish generation
if int(sum(probs[-1] >= threshold)) > 0 or idx >= maxlen:
# check minimum length
if idx < minlen:
continue
outs = torch.cat(outs, dim=0).unsqueeze(0).transpose(1, 2) # (L, odim) -> (1, L, odim) -> (1, odim, L)
if self.postnet is not None:
outs = outs + self.postnet(outs) # (1, odim, L)
outs = outs.transpose(2, 1).squeeze(0) # (L, odim)
probs = torch.cat(probs, dim=0)
break
# concatenate attention weights -> (#layers, #heads, L, T)
att_ws = torch.stack(att_ws, dim=0)
return outs, probs, att_ws
def calculate_all_attentions(self, xs, ilens, ys, olens,
spembs=None, skip_output=False, keep_tensor=False, *args, **kwargs):
"""Calculate all of the attention weights.
Args:
xs (Tensor): Batch of padded character ids (B, Tmax).
ilens (LongTensor): Batch of lengths of each input batch (B,).
ys (Tensor): Batch of padded target features (B, Lmax, odim).
olens (LongTensor): Batch of the lengths of each target (B,).
spembs (Tensor, optional): Batch of speaker embedding vectors (B, spk_embed_dim).
skip_output (bool, optional): Whether to skip calculating the final output.
keep_tensor (bool, optional): Whether to keep original tensor.
Returns:
dict: Dict of attention weights and outputs.
"""
with torch.no_grad():
# forward encoder
x_masks = self._source_mask(ilens)
hs, _ = self.encoder(xs, x_masks)
# integrate speaker embedding
if self.spk_embed_dim is not None:
hs = self._integrate_with_spk_embed(hs, spembs)
# thin out frames for reduction factor (B, Lmax, odim) -> (B, Lmax//r, odim)
if self.reduction_factor > 1:
ys_in = ys[:, self.reduction_factor - 1::self.reduction_factor]
olens_in = olens.new([olen // self.reduction_factor for olen in olens])
else:
ys_in, olens_in = ys, olens
# add first zero frame and remove last frame for auto-regressive
ys_in = self._add_first_frame_and_remove_last_frame(ys_in)
# forward decoder
y_masks = self._target_mask(olens_in)
xy_masks = self._source_to_target_mask(ilens, olens_in)
zs, _ = self.decoder(ys_in, y_masks, hs, xy_masks)
# calculate final outputs
if not skip_output:
before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim)
if self.postnet is None:
after_outs = before_outs
else:
after_outs = before_outs + self.postnet(before_outs.transpose(1, 2)).transpose(1, 2)
# modify mod part of output lengths due to reduction factor > 1
if self.reduction_factor > 1:
olens = olens.new([olen - olen % self.reduction_factor for olen in olens])
# store into dict
att_ws_dict = dict()
if keep_tensor:
for name, m in self.named_modules():
if isinstance(m, MultiHeadedAttention):
att_ws_dict[name] = m.attn
if not skip_output:
att_ws_dict["before_postnet_fbank"] = before_outs
att_ws_dict["after_postnet_fbank"] = after_outs
else:
for name, m in self.named_modules():
if isinstance(m, MultiHeadedAttention):
attn = m.attn.cpu().numpy()
if "encoder" in name:
attn = [a[:, :l, :l] for a, l in zip(attn, ilens.tolist())]
elif "decoder" in name:
if "src" in name:
attn = [a[:, :ol, :il] for a, il, ol in zip(attn, ilens.tolist(), olens_in.tolist())]
elif "self" in name:
attn = [a[:, :l, :l] for a, l in zip(attn, olens_in.tolist())]
else:
logging.warning("unknown attention module: " + name)
else:
logging.warning("unknown attention module: " + name)
att_ws_dict[name] = attn
if not skip_output:
before_outs = before_outs.cpu().numpy()
after_outs = after_outs.cpu().numpy()
att_ws_dict["before_postnet_fbank"] = [m[:l].T for m, l in zip(before_outs, olens.tolist())]
att_ws_dict["after_postnet_fbank"] = [m[:l].T for m, l in zip(after_outs, olens.tolist())]
return att_ws_dict
def _integrate_with_spk_embed(self, hs, spembs):
"""Integrate speaker embedding with hidden states.
Args:
hs (Tensor): Batch of hidden state sequences (B, Tmax, adim).
spembs (Tensor): Batch of speaker embeddings (B, spk_embed_dim).
Returns:
Tensor: Batch of integrated hidden state sequences (B, Tmax, adim)
"""
if self.spk_embed_integration_type == "add":
# apply projection and then add to hidden states
spembs = self.projection(F.normalize(spembs))
hs = hs + spembs.unsqueeze(1)
elif self.spk_embed_integration_type == "concat":
# concat hidden states with spk embeds and then apply projection
spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1)
hs = self.projection(torch.cat([hs, spembs], dim=-1))
else:
raise NotImplementedError("support only add or concat.")
return hs
def _source_mask(self, ilens):
"""Make masks for self-attention.
Args:
ilens (LongTensor or List): Batch of lengths (B,).
Returns:
Tensor: Mask tensor for self-attention.
dtype=torch.uint8 in PyTorch 1.2-
dtype=torch.bool in PyTorch 1.2+ (including 1.2)
Examples:
>>> ilens = [5, 3]
>>> self._source_mask(ilens)
tensor([[[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1]],
[[1, 1, 1, 0, 0],
[1, 1, 1, 0, 0],
[1, 1, 1, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]]], dtype=torch.uint8)
"""
x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device)
return x_masks.unsqueeze(-2) & x_masks.unsqueeze(-1)
def _target_mask(self, olens):
"""Make masks for masked self-attention.
Args:
olens (LongTensor or List): Batch of lengths (B,).
Returns:
Tensor: Mask tensor for masked self-attention.
dtype=torch.uint8 in PyTorch 1.2-
dtype=torch.bool in PyTorch 1.2+ (including 1.2)
Examples:
>>> olens = [5, 3]
>>> self._target_mask(olens)
tensor([[[1, 0, 0, 0, 0],
[1, 1, 0, 0, 0],
[1, 1, 1, 0, 0],
[1, 1, 1, 1, 0],
[1, 1, 1, 1, 1]],
[[1, 0, 0, 0, 0],
[1, 1, 0, 0, 0],
[1, 1, 1, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]]], dtype=torch.uint8)
"""
y_masks = make_non_pad_mask(olens).to(next(self.parameters()).device)
s_masks = subsequent_mask(y_masks.size(-1), device=y_masks.device).unsqueeze(0)
return y_masks.unsqueeze(-2) & s_masks & y_masks.unsqueeze(-1)
def _source_to_target_mask(self, ilens, olens):
"""Make masks for encoder-decoder attention.
Args:
ilens (LongTensor or List): Batch of lengths (B,).
olens (LongTensor or List): Batch of lengths (B,).
Returns:
Tensor: Mask tensor for encoder-decoder attention.
dtype=torch.uint8 in PyTorch 1.2-
dtype=torch.bool in PyTorch 1.2+ (including 1.2)
Examples:
>>> ilens = [4, 2]
>>> olens = [5, 3]
>>> self._source_to_target_mask(ilens)
tensor([[[1, 1, 1, 1],
[1, 1, 1, 1],
[1, 1, 1, 1],
[1, 1, 1, 1],
[1, 1, 1, 1]],
[[1, 1, 0, 0],
[1, 1, 0, 0],
[1, 1, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0]]], dtype=torch.uint8)
"""
x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device)
y_masks = make_non_pad_mask(olens).to(next(self.parameters()).device)
return x_masks.unsqueeze(-2) & y_masks.unsqueeze(-1)
@property
def base_plot_keys(self):
"""Return base key names to plot during training. keys should match what `chainer.reporter` reports.
If you add the key `loss`, the reporter will report `main/loss` and `validation/main/loss` values.
also `loss.png` will be created as a figure visulizing `main/loss` and `validation/main/loss` values.
Returns:
list: List of strings which are base keys to plot during training.
"""
plot_keys = ["loss", "l1_loss", "l2_loss", "bce_loss"]
if self.use_scaled_pos_enc:
plot_keys += ["encoder_alpha", "decoder_alpha"]
if self.use_guided_attn_loss:
if "encoder" in self.modules_applied_guided_attn:
plot_keys += ["enc_attn_loss"]
if "decoder" in self.modules_applied_guided_attn:
plot_keys += ["dec_attn_loss"]
if "encoder-decoder" in self.modules_applied_guided_attn:
plot_keys += ["enc_dec_attn_loss"]
return plot_keys
attention_plot_class
property
readonly
¶Return plot class for attention weight plot.
base_plot_keys
property
readonly
¶Return base key names to plot during training. Keys should match what chainer.reporter reports. If you add the key loss, the reporter will report main/loss and validation/main/loss values; loss.png will also be created as a figure visualizing main/loss and validation/main/loss.
Returns:
Type | Description
---|---
list | List of strings which are base keys to plot during training.
__init__(self, idim, odim, args=None)
special
¶Initialize TTS-Transformer module.
Parameters:
Name | Type | Description | Default
---|---|---|---
idim | int | Dimension of the inputs. | required
odim | int | Dimension of the outputs. | required
args | Namespace | | None
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_tts_transformer.py
def __init__(self, idim, odim, args=None):
"""Initialize TTS-Transformer module.
Args:
idim (int): Dimension of the inputs.
odim (int): Dimension of the outputs.
args (Namespace, optional):
- embed_dim (int): Dimension of character embedding.
- eprenet_conv_layers (int): Number of encoder prenet convolution layers.
- eprenet_conv_chans (int): Number of encoder prenet convolution channels.
- eprenet_conv_filts (int): Filter size of encoder prenet convolution.
- dprenet_layers (int): Number of decoder prenet layers.
- dprenet_units (int): Number of decoder prenet hidden units.
- elayers (int): Number of encoder layers.
- eunits (int): Number of encoder hidden units.
- adim (int): Number of attention transformation dimensions.
- aheads (int): Number of heads for multi head attention.
- dlayers (int): Number of decoder layers.
- dunits (int): Number of decoder hidden units.
- postnet_layers (int): Number of postnet layers.
- postnet_chans (int): Number of postnet channels.
- postnet_filts (int): Filter size of postnet.
- use_scaled_pos_enc (bool): Whether to use trainable scaled positional encoding.
- use_batch_norm (bool): Whether to use batch normalization in encoder prenet.
- encoder_normalize_before (bool): Whether to perform layer normalization before encoder block.
- decoder_normalize_before (bool): Whether to perform layer normalization before decoder block.
- encoder_concat_after (bool): Whether to concatenate attention layer's input and output in encoder.
- decoder_concat_after (bool): Whether to concatenate attention layer's input and output in decoder.
- reduction_factor (int): Reduction factor.
- spk_embed_dim (int): Number of speaker embedding dimensions.
- spk_embed_integration_type: How to integrate speaker embedding.
- transformer_init (float): How to initialize transformer parameters.
- transformer_lr (float): Initial value of learning rate.
- transformer_warmup_steps (int): Optimizer warmup steps.
- transformer_enc_dropout_rate (float): Dropout rate in encoder except attention & positional encoding.
- transformer_enc_positional_dropout_rate (float): Dropout rate after encoder positional encoding.
- transformer_enc_attn_dropout_rate (float): Dropout rate in encoder self-attention module.
- transformer_dec_dropout_rate (float): Dropout rate in decoder except attention & positional encoding.
- transformer_dec_positional_dropout_rate (float): Dropout rate after decoder positional encoding.
- transformer_dec_attn_dropout_rate (float): Dropout rate in decoder self-attention module.
- transformer_enc_dec_attn_dropout_rate (float): Dropout rate in encoder-decoder attention module.
- eprenet_dropout_rate (float): Dropout rate in encoder prenet.
- dprenet_dropout_rate (float): Dropout rate in decoder prenet.
- postnet_dropout_rate (float): Dropout rate in postnet.
- use_masking (bool): Whether to apply masking for padded part in loss calculation.
- use_weighted_masking (bool): Whether to apply weighted masking in loss calculation.
- bce_pos_weight (float): Positive sample weight in bce calculation (only for use_masking=true).
- loss_type (str): How to calculate loss.
- use_guided_attn_loss (bool): Whether to use guided attention loss.
- num_heads_applied_guided_attn (int): Number of heads in each layer to apply guided attention loss.
- num_layers_applied_guided_attn (int): Number of layers to apply guided attention loss.
- modules_applied_guided_attn (list): List of module names to apply guided attention loss.
- guided_attn_loss_sigma (float): Sigma in guided attention loss.
- guided_attn_loss_lambda (float): Lambda in guided attention loss.
"""
# initialize base classes
TTSInterface.__init__(self)
torch.nn.Module.__init__(self)
# fill missing arguments
args = fill_missing_args(args, self.add_arguments)
# store hyperparameters
self.idim = idim
self.odim = odim
self.spk_embed_dim = args.spk_embed_dim
if self.spk_embed_dim is not None:
self.spk_embed_integration_type = args.spk_embed_integration_type
self.use_scaled_pos_enc = args.use_scaled_pos_enc
self.reduction_factor = args.reduction_factor
self.loss_type = args.loss_type
self.use_guided_attn_loss = args.use_guided_attn_loss
if self.use_guided_attn_loss:
if args.num_layers_applied_guided_attn == -1:
self.num_layers_applied_guided_attn = args.elayers
else:
self.num_layers_applied_guided_attn = args.num_layers_applied_guided_attn
if args.num_heads_applied_guided_attn == -1:
self.num_heads_applied_guided_attn = args.aheads
else:
self.num_heads_applied_guided_attn = args.num_heads_applied_guided_attn
self.modules_applied_guided_attn = args.modules_applied_guided_attn
# use idx 0 as padding idx
padding_idx = 0
# get positional encoding class
pos_enc_class = ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding
# define transformer encoder
if args.eprenet_conv_layers != 0:
# encoder prenet
encoder_input_layer = torch.nn.Sequential(
EncoderPrenet(
idim=idim,
embed_dim=args.embed_dim,
elayers=0,
econv_layers=args.eprenet_conv_layers,
econv_chans=args.eprenet_conv_chans,
econv_filts=args.eprenet_conv_filts,
use_batch_norm=args.use_batch_norm,
dropout_rate=args.eprenet_dropout_rate,
padding_idx=padding_idx
),
torch.nn.Linear(args.eprenet_conv_chans, args.adim)
)
else:
encoder_input_layer = torch.nn.Embedding(
num_embeddings=idim,
embedding_dim=args.adim,
padding_idx=padding_idx
)
self.encoder = Encoder(
idim=idim,
attention_dim=args.adim,
attention_heads=args.aheads,
linear_units=args.eunits,
num_blocks=args.elayers,
input_layer=encoder_input_layer,
dropout_rate=args.transformer_enc_dropout_rate,
positional_dropout_rate=args.transformer_enc_positional_dropout_rate,
attention_dropout_rate=args.transformer_enc_attn_dropout_rate,
pos_enc_class=pos_enc_class,
normalize_before=args.encoder_normalize_before,
concat_after=args.encoder_concat_after,
positionwise_layer_type=args.positionwise_layer_type,
positionwise_conv_kernel_size=args.positionwise_conv_kernel_size,
)
# define projection layer
if self.spk_embed_dim is not None:
if self.spk_embed_integration_type == "add":
self.projection = torch.nn.Linear(self.spk_embed_dim, args.adim)
else:
self.projection = torch.nn.Linear(args.adim + self.spk_embed_dim, args.adim)
# define transformer decoder
if args.dprenet_layers != 0:
# decoder prenet
decoder_input_layer = torch.nn.Sequential(
DecoderPrenet(
idim=odim,
n_layers=args.dprenet_layers,
n_units=args.dprenet_units,
dropout_rate=args.dprenet_dropout_rate
),
torch.nn.Linear(args.dprenet_units, args.adim)
)
else:
decoder_input_layer = "linear"
self.decoder = Decoder(
odim=-1,
attention_dim=args.adim,
attention_heads=args.aheads,
linear_units=args.dunits,
num_blocks=args.dlayers,
dropout_rate=args.transformer_dec_dropout_rate,
positional_dropout_rate=args.transformer_dec_positional_dropout_rate,
self_attention_dropout_rate=args.transformer_dec_attn_dropout_rate,
src_attention_dropout_rate=args.transformer_enc_dec_attn_dropout_rate,
input_layer=decoder_input_layer,
use_output_layer=False,
pos_enc_class=pos_enc_class,
normalize_before=args.decoder_normalize_before,
concat_after=args.decoder_concat_after
)
# define final projection
self.feat_out = torch.nn.Linear(args.adim, odim * args.reduction_factor)
self.prob_out = torch.nn.Linear(args.adim, args.reduction_factor)
# define postnet
self.postnet = None if args.postnet_layers == 0 else Postnet(
idim=idim,
odim=odim,
n_layers=args.postnet_layers,
n_chans=args.postnet_chans,
n_filts=args.postnet_filts,
use_batch_norm=args.use_batch_norm,
dropout_rate=args.postnet_dropout_rate
)
# define loss function
self.criterion = TransformerLoss(use_masking=args.use_masking,
use_weighted_masking=args.use_weighted_masking,
bce_pos_weight=args.bce_pos_weight)
if self.use_guided_attn_loss:
self.attn_criterion = GuidedMultiHeadAttentionLoss(
sigma=args.guided_attn_loss_sigma,
alpha=args.guided_attn_loss_lambda,
)
# initialize parameters
self._reset_parameters(init_type=args.transformer_init,
init_enc_alpha=args.initial_encoder_alpha,
init_dec_alpha=args.initial_decoder_alpha)
# load pretrained model
if args.pretrained_model is not None:
self.load_pretrained_model(args.pretrained_model)
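As a usage sketch (not part of the source above): because the constructor fills any hyperparameters missing from args with the defaults registered in add_arguments (via fill_missing_args, which in upstream ESPnet also accepts None), the model can be built directly with its defaults. The import path below is an assumption based on the "Source code in" location shown above and requires the adviser package root on the Python path.
from tools.espnet_minimal.nets.pytorch_backend.e2e_tts_transformer import Transformer  # assumed import path

# idim: size of the character vocabulary, odim: output feature dimension (e.g. 80-dim mel spectrogram)
model = Transformer(idim=50, odim=80, args=None)  # args=None -> all defaults from add_arguments
print(sum(p.numel() for p in model.parameters()))  # rough parameter count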
add_arguments(parser)
staticmethod
¶Add model-specific arguments to the parser.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_tts_transformer.py
@staticmethod
def add_arguments(parser):
"""Add model-specific arguments to the parser."""
group = parser.add_argument_group("transformer model setting")
# network structure related
group.add_argument("--embed-dim", default=512, type=int,
help="Dimension of character embedding in encoder prenet")
group.add_argument("--eprenet-conv-layers", default=3, type=int,
help="Number of encoder prenet convolution layers")
group.add_argument("--eprenet-conv-chans", default=256, type=int,
help="Number of encoder prenet convolution channels")
group.add_argument("--eprenet-conv-filts", default=5, type=int,
help="Filter size of encoder prenet convolution")
group.add_argument("--dprenet-layers", default=2, type=int,
help="Number of decoder prenet layers")
group.add_argument("--dprenet-units", default=256, type=int,
help="Number of decoder prenet hidden units")
group.add_argument("--elayers", default=3, type=int,
help="Number of encoder layers")
group.add_argument("--eunits", default=1536, type=int,
help="Number of encoder hidden units")
group.add_argument("--adim", default=384, type=int,
help="Number of attention transformation dimensions")
group.add_argument("--aheads", default=4, type=int,
help="Number of heads for multi head attention")
group.add_argument("--dlayers", default=3, type=int,
help="Number of decoder layers")
group.add_argument("--dunits", default=1536, type=int,
help="Number of decoder hidden units")
group.add_argument("--positionwise-layer-type", default="linear", type=str,
choices=["linear", "conv1d", "conv1d-linear"],
help="Positionwise layer type.")
group.add_argument("--positionwise-conv-kernel-size", default=1, type=int,
help="Kernel size of positionwise conv1d layer")
group.add_argument("--postnet-layers", default=5, type=int,
help="Number of postnet layers")
group.add_argument("--postnet-chans", default=256, type=int,
help="Number of postnet channels")
group.add_argument("--postnet-filts", default=5, type=int,
help="Filter size of postnet")
group.add_argument("--use-scaled-pos-enc", default=True, type=strtobool,
help="Use trainable scaled positional encoding instead of the fixed scale one.")
group.add_argument("--use-batch-norm", default=True, type=strtobool,
help="Whether to use batch normalization")
group.add_argument("--encoder-normalize-before", default=False, type=strtobool,
help="Whether to apply layer norm before encoder block")
group.add_argument("--decoder-normalize-before", default=False, type=strtobool,
help="Whether to apply layer norm before decoder block")
group.add_argument("--encoder-concat-after", default=False, type=strtobool,
help="Whether to concatenate attention layer's input and output in encoder")
group.add_argument("--decoder-concat-after", default=False, type=strtobool,
help="Whether to concatenate attention layer's input and output in decoder")
group.add_argument("--reduction-factor", default=1, type=int,
help="Reduction factor")
group.add_argument("--spk-embed-dim", default=None, type=int,
help="Number of speaker embedding dimensions")
group.add_argument("--spk-embed-integration-type", type=str, default="add",
choices=["add", "concat"],
help="How to integrate speaker embedding")
# training related
group.add_argument("--transformer-init", type=str, default="pytorch",
choices=["pytorch", "xavier_uniform", "xavier_normal",
"kaiming_uniform", "kaiming_normal"],
help="How to initialize transformer parameters")
group.add_argument("--initial-encoder-alpha", type=float, default=1.0,
help="Initial alpha value in encoder's ScaledPositionalEncoding")
group.add_argument("--initial-decoder-alpha", type=float, default=1.0,
help="Initial alpha value in decoder's ScaledPositionalEncoding")
group.add_argument("--transformer-lr", default=1.0, type=float,
help="Initial value of learning rate")
group.add_argument("--transformer-warmup-steps", default=4000, type=int,
help="Optimizer warmup steps")
group.add_argument("--transformer-enc-dropout-rate", default=0.1, type=float,
help="Dropout rate for transformer encoder except for attention")
group.add_argument("--transformer-enc-positional-dropout-rate", default=0.1, type=float,
help="Dropout rate for transformer encoder positional encoding")
group.add_argument("--transformer-enc-attn-dropout-rate", default=0.1, type=float,
help="Dropout rate for transformer encoder self-attention")
group.add_argument("--transformer-dec-dropout-rate", default=0.1, type=float,
help="Dropout rate for transformer decoder except for attention and pos encoding")
group.add_argument("--transformer-dec-positional-dropout-rate", default=0.1, type=float,
help="Dropout rate for transformer decoder positional encoding")
group.add_argument("--transformer-dec-attn-dropout-rate", default=0.1, type=float,
help="Dropout rate for transformer decoder self-attention")
group.add_argument("--transformer-enc-dec-attn-dropout-rate", default=0.1, type=float,
help="Dropout rate for transformer encoder-decoder attention")
group.add_argument("--eprenet-dropout-rate", default=0.5, type=float,
help="Dropout rate in encoder prenet")
group.add_argument("--dprenet-dropout-rate", default=0.5, type=float,
help="Dropout rate in decoder prenet")
group.add_argument("--postnet-dropout-rate", default=0.5, type=float,
help="Dropout rate in postnet")
group.add_argument("--pretrained-model", default=None, type=str,
help="Pretrained model path")
# loss related
group.add_argument("--use-masking", default=True, type=strtobool,
help="Whether to use masking in calculation of loss")
group.add_argument("--use-weighted-masking", default=False, type=strtobool,
help="Whether to use weighted masking in calculation of loss")
group.add_argument("--loss-type", default="L1", choices=["L1", "L2", "L1+L2"],
help="How to calc loss")
group.add_argument("--bce-pos-weight", default=5.0, type=float,
help="Positive sample weight in BCE calculation (only for use-masking=True)")
group.add_argument("--use-guided-attn-loss", default=False, type=strtobool,
help="Whether to use guided attention loss")
group.add_argument("--guided-attn-loss-sigma", default=0.4, type=float,
help="Sigma in guided attention loss")
group.add_argument("--guided-attn-loss-lambda", default=1.0, type=float,
help="Lambda in guided attention loss")
group.add_argument("--num-heads-applied-guided-attn", default=2, type=int,
help="Number of heads in each layer to be applied guided attention loss"
"if set -1, all of the heads will be applied.")
group.add_argument("--num-layers-applied-guided-attn", default=2, type=int,
help="Number of layers to be applied guided attention loss"
"if set -1, all of the layers will be applied.")
group.add_argument("--modules-applied-guided-attn", type=str, nargs="+",
default=["encoder-decoder"],
help="Module name list to be applied guided attention loss")
return parser
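A small sketch of how these defaults can be inspected outside of training (hypothetical usage; the import path is the same assumption as in the construction sketch above):
import argparse
from tools.espnet_minimal.nets.pytorch_backend.e2e_tts_transformer import Transformer  # assumed import path

parser = argparse.ArgumentParser()
Transformer.add_arguments(parser)
args, _ = parser.parse_known_args([])   # parse no CLI flags -> pure defaults
print(args.adim, args.aheads, args.elayers, args.dlayers)  # 384 4 3 3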
calculate_all_attentions(self, xs, ilens, ys, olens, spembs=None, skip_output=False, keep_tensor=False, *args, **kwargs)
¶Calculate all of the attention weights.
Parameters:
Name | Type | Description | Default
---|---|---|---
xs | Tensor | Batch of padded character ids (B, Tmax). | required
ilens | LongTensor | Batch of lengths of each input batch (B,). | required
ys | Tensor | Batch of padded target features (B, Lmax, odim). | required
olens | LongTensor | Batch of the lengths of each target (B,). | required
spembs | Tensor | Batch of speaker embedding vectors (B, spk_embed_dim). | None
skip_output | bool | Whether to skip calculating the final output. | False
keep_tensor | bool | Whether to keep original tensor. | False
Returns:
Type | Description
---|---
dict | Dict of attention weights and outputs.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_tts_transformer.py
def calculate_all_attentions(self, xs, ilens, ys, olens,
spembs=None, skip_output=False, keep_tensor=False, *args, **kwargs):
"""Calculate all of the attention weights.
Args:
xs (Tensor): Batch of padded character ids (B, Tmax).
ilens (LongTensor): Batch of lengths of each input batch (B,).
ys (Tensor): Batch of padded target features (B, Lmax, odim).
olens (LongTensor): Batch of the lengths of each target (B,).
spembs (Tensor, optional): Batch of speaker embedding vectors (B, spk_embed_dim).
skip_output (bool, optional): Whether to skip calculating the final output.
keep_tensor (bool, optional): Whether to keep original tensor.
Returns:
dict: Dict of attention weights and outputs.
"""
with torch.no_grad():
# forward encoder
x_masks = self._source_mask(ilens)
hs, _ = self.encoder(xs, x_masks)
# integrate speaker embedding
if self.spk_embed_dim is not None:
hs = self._integrate_with_spk_embed(hs, spembs)
# thin out frames for reduction factor (B, Lmax, odim) -> (B, Lmax//r, odim)
if self.reduction_factor > 1:
ys_in = ys[:, self.reduction_factor - 1::self.reduction_factor]
olens_in = olens.new([olen // self.reduction_factor for olen in olens])
else:
ys_in, olens_in = ys, olens
# add first zero frame and remove last frame for auto-regressive
ys_in = self._add_first_frame_and_remove_last_frame(ys_in)
# forward decoder
y_masks = self._target_mask(olens_in)
xy_masks = self._source_to_target_mask(ilens, olens_in)
zs, _ = self.decoder(ys_in, y_masks, hs, xy_masks)
# calculate final outputs
if not skip_output:
before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim)
if self.postnet is None:
after_outs = before_outs
else:
after_outs = before_outs + self.postnet(before_outs.transpose(1, 2)).transpose(1, 2)
# modify mod part of output lengths due to reduction factor > 1
if self.reduction_factor > 1:
olens = olens.new([olen - olen % self.reduction_factor for olen in olens])
# store into dict
att_ws_dict = dict()
if keep_tensor:
for name, m in self.named_modules():
if isinstance(m, MultiHeadedAttention):
att_ws_dict[name] = m.attn
if not skip_output:
att_ws_dict["before_postnet_fbank"] = before_outs
att_ws_dict["after_postnet_fbank"] = after_outs
else:
for name, m in self.named_modules():
if isinstance(m, MultiHeadedAttention):
attn = m.attn.cpu().numpy()
if "encoder" in name:
attn = [a[:, :l, :l] for a, l in zip(attn, ilens.tolist())]
elif "decoder" in name:
if "src" in name:
attn = [a[:, :ol, :il] for a, il, ol in zip(attn, ilens.tolist(), olens_in.tolist())]
elif "self" in name:
attn = [a[:, :l, :l] for a, l in zip(attn, olens_in.tolist())]
else:
logging.warning("unknown attention module: " + name)
else:
logging.warning("unknown attention module: " + name)
att_ws_dict[name] = attn
if not skip_output:
before_outs = before_outs.cpu().numpy()
after_outs = after_outs.cpu().numpy()
att_ws_dict["before_postnet_fbank"] = [m[:l].T for m, l in zip(before_outs, olens.tolist())]
att_ws_dict["after_postnet_fbank"] = [m[:l].T for m, l in zip(after_outs, olens.tolist())]
return att_ws_dict
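For orientation, a hypothetical call with random batch data (shapes are illustrative; model refers to the construction sketch above). Without keep_tensor, each value in the returned dict is a list of per-utterance numpy arrays cropped to the true lengths:
import torch

B, Tmax, Lmax, odim = 2, 13, 40, 80
xs = torch.randint(1, 50, (B, Tmax))
xs[1, 9:] = 0                         # pad the second utterance with index 0
ilens = torch.tensor([13, 9])
ys = torch.randn(B, Lmax, odim)
olens = torch.tensor([40, 31])

att_ws_dict = model.calculate_all_attentions(xs, ilens, ys, olens)
for name, values in att_ws_dict.items():
    print(name, [v.shape for v in values])  # attention maps and before/after-postnet features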
forward(self, xs, ilens, ys, labels, olens, spembs=None, *args, **kwargs)
¶Calculate forward propagation.
Parameters:
Name | Type | Description | Default
---|---|---|---
xs | Tensor | Batch of padded character ids (B, Tmax). | required
ilens | LongTensor | Batch of lengths of each input batch (B,). | required
ys | Tensor | Batch of padded target features (B, Lmax, odim). | required
olens | LongTensor | Batch of the lengths of each target (B,). | required
spembs | Tensor | Batch of speaker embedding vectors (B, spk_embed_dim). | None
Returns:
Type | Description
---|---
Tensor | Loss value.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_tts_transformer.py
def forward(self, xs, ilens, ys, labels, olens, spembs=None, *args, **kwargs):
"""Calculate forward propagation.
Args:
xs (Tensor): Batch of padded character ids (B, Tmax).
ilens (LongTensor): Batch of lengths of each input batch (B,).
ys (Tensor): Batch of padded target features (B, Lmax, odim).
olens (LongTensor): Batch of the lengths of each target (B,).
spembs (Tensor, optional): Batch of speaker embedding vectors (B, spk_embed_dim).
Returns:
Tensor: Loss value.
"""
# remove unnecessary padded part (for multi-gpus)
max_ilen = max(ilens)
max_olen = max(olens)
if max_ilen != xs.shape[1]:
xs = xs[:, :max_ilen]
if max_olen != ys.shape[1]:
ys = ys[:, :max_olen]
labels = labels[:, :max_olen]
# forward encoder
x_masks = self._source_mask(ilens)
hs, _ = self.encoder(xs, x_masks)
# integrate speaker embedding
if self.spk_embed_dim is not None:
hs = self._integrate_with_spk_embed(hs, spembs)
# thin out frames for reduction factor (B, Lmax, odim) -> (B, Lmax//r, odim)
if self.reduction_factor > 1:
ys_in = ys[:, self.reduction_factor - 1::self.reduction_factor]
olens_in = olens.new([olen // self.reduction_factor for olen in olens])
else:
ys_in, olens_in = ys, olens
# add first zero frame and remove last frame for auto-regressive
ys_in = self._add_first_frame_and_remove_last_frame(ys_in)
# forward decoder
y_masks = self._target_mask(olens_in)
xy_masks = self._source_to_target_mask(ilens, olens_in)
zs, _ = self.decoder(ys_in, y_masks, hs, xy_masks)
# (B, Lmax//r, odim * r) -> (B, Lmax//r * r, odim)
before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim)
# (B, Lmax//r, r) -> (B, Lmax//r * r)
logits = self.prob_out(zs).view(zs.size(0), -1)
# postnet -> (B, Lmax//r * r, odim)
if self.postnet is None:
after_outs = before_outs
else:
after_outs = before_outs + self.postnet(before_outs.transpose(1, 2)).transpose(1, 2)
# modify mod part of groundtruth
if self.reduction_factor > 1:
olens = olens.new([olen - olen % self.reduction_factor for olen in olens])
max_olen = max(olens)
ys = ys[:, :max_olen]
labels = labels[:, :max_olen]
labels[:, -1] = 1.0 # make sure at least one frame has 1
# calculate loss values
l1_loss, l2_loss, bce_loss = self.criterion(
after_outs, before_outs, logits, ys, labels, olens)
if self.loss_type == "L1":
loss = l1_loss + bce_loss
elif self.loss_type == "L2":
loss = l2_loss + bce_loss
elif self.loss_type == "L1+L2":
loss = l1_loss + l2_loss + bce_loss
else:
raise ValueError("unknown --loss-type " + self.loss_type)
report_keys = [
{"l1_loss": l1_loss.item()},
{"l2_loss": l2_loss.item()},
{"bce_loss": bce_loss.item()},
{"loss": loss.item()},
]
# calculate guided attention loss
if self.use_guided_attn_loss:
# calculate for encoder
if "encoder" in self.modules_applied_guided_attn:
att_ws = []
for idx, layer_idx in enumerate(reversed(range(len(self.encoder.encoders)))):
att_ws += [self.encoder.encoders[layer_idx].self_attn.attn[:, :self.num_heads_applied_guided_attn]]
if idx + 1 == self.num_layers_applied_guided_attn:
break
att_ws = torch.cat(att_ws, dim=1) # (B, H*L, T_in, T_in)
enc_attn_loss = self.attn_criterion(att_ws, ilens, ilens)
loss = loss + enc_attn_loss
report_keys += [{"enc_attn_loss": enc_attn_loss.item()}]
# calculate for decoder
if "decoder" in self.modules_applied_guided_attn:
att_ws = []
for idx, layer_idx in enumerate(reversed(range(len(self.decoder.decoders)))):
att_ws += [self.decoder.decoders[layer_idx].self_attn.attn[:, :self.num_heads_applied_guided_attn]]
if idx + 1 == self.num_layers_applied_guided_attn:
break
att_ws = torch.cat(att_ws, dim=1) # (B, H*L, T_out, T_out)
dec_attn_loss = self.attn_criterion(att_ws, olens_in, olens_in)
loss = loss + dec_attn_loss
report_keys += [{"dec_attn_loss": dec_attn_loss.item()}]
# calculate for encoder-decoder
if "encoder-decoder" in self.modules_applied_guided_attn:
att_ws = []
for idx, layer_idx in enumerate(reversed(range(len(self.decoder.decoders)))):
att_ws += [self.decoder.decoders[layer_idx].src_attn.attn[:, :self.num_heads_applied_guided_attn]]
if idx + 1 == self.num_layers_applied_guided_attn:
break
att_ws = torch.cat(att_ws, dim=1) # (B, H*L, T_out, T_in)
enc_dec_attn_loss = self.attn_criterion(att_ws, ilens, olens_in)
loss = loss + enc_dec_attn_loss
report_keys += [{"enc_dec_attn_loss": enc_dec_attn_loss.item()}]
# report extra information
if self.use_scaled_pos_enc:
report_keys += [
{"encoder_alpha": self.encoder.embed[-1].alpha.data.item()},
{"decoder_alpha": self.decoder.embed[-1].alpha.data.item()},
]
self.reporter.report(report_keys)
return loss
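A minimal training-step sketch around this forward pass (the optimizer choice and shapes are illustrative assumptions; model and the batch tensors follow the earlier sketches; labels carries the stop flag, with a 1 at the final true frame of each utterance):
import torch

labels = torch.zeros(B, Lmax)
labels[0, olens[0] - 1] = 1.0
labels[1, olens[1] - 1] = 1.0

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss = model(xs, ilens, ys, labels, olens)   # scalar training loss (also reported internally)
optimizer.zero_grad()
loss.backward()
optimizer.step()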
inference(self, x, inference_args, spemb=None, *args, **kwargs)
¶Generate the sequence of features given the sequences of characters.
Parameters:
Name | Type | Description | Default
---|---|---|---
x | Tensor | Input sequence of characters (T,). | required
inference_args | Namespace | | required
spemb | Tensor | Speaker embedding vector (spk_embed_dim). | None
Returns:
Type | Description
---|---
Tensor | Output sequence of features (L, odim).
Tensor | Output sequence of stop probabilities (L,).
Tensor | Encoder-decoder (source) attention weights (#layers, #heads, L, T).
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/e2e_tts_transformer.py
def inference(self, x, inference_args, spemb=None, *args, **kwargs):
"""Generate the sequence of features given the sequences of characters.
Args:
x (Tensor): Input sequence of characters (T,).
inference_args (Namespace):
- threshold (float): Threshold in inference.
- minlenratio (float): Minimum length ratio in inference.
- maxlenratio (float): Maximum length ratio in inference.
spemb (Tensor, optional): Speaker embedding vector (spk_embed_dim).
Returns:
Tensor: Output sequence of features (L, odim).
Tensor: Output sequence of stop probabilities (L,).
Tensor: Encoder-decoder (source) attention weights (#layers, #heads, L, T).
"""
# get options
threshold = inference_args.threshold
minlenratio = inference_args.minlenratio
maxlenratio = inference_args.maxlenratio
use_att_constraint = getattr(inference_args, "use_att_constraint", False) # keep compatibility
if use_att_constraint:
logging.warning("Attention constraint is not yet supported in Transformer. Not enabled.")
# forward encoder
xs = x.unsqueeze(0)
hs, _ = self.encoder(xs, None)
# integrate speaker embedding
if self.spk_embed_dim is not None:
spembs = spemb.unsqueeze(0)
hs = self._integrate_with_spk_embed(hs, spembs)
# set limits of length
maxlen = int(hs.size(1) * maxlenratio / self.reduction_factor)
minlen = int(hs.size(1) * minlenratio / self.reduction_factor)
# initialize
idx = 0
ys = hs.new_zeros(1, 1, self.odim)
outs, probs = [], []
# forward decoder step-by-step
z_cache = self.decoder.init_state(x)
while True:
# update index
idx += 1
# calculate output and stop prob at idx-th step
y_masks = subsequent_mask(idx).unsqueeze(0).to(x.device)
z, z_cache = self.decoder.forward_one_step(ys, y_masks, hs, cache=z_cache) # (B, adim)
outs += [self.feat_out(z).view(self.reduction_factor, self.odim)] # [(r, odim), ...]
probs += [torch.sigmoid(self.prob_out(z))[0]] # [(r), ...]
# update next inputs
ys = torch.cat((ys, outs[-1][-1].view(1, 1, self.odim)), dim=1) # (1, idx + 1, odim)
# get attention weights
att_ws_ = []
for name, m in self.named_modules():
if isinstance(m, MultiHeadedAttention) and "src" in name:
att_ws_ += [m.attn[0, :, -1].unsqueeze(1)] # [(#heads, 1, T),...]
if idx == 1:
att_ws = att_ws_
else:
# [(#heads, l, T), ...]
att_ws = [torch.cat([att_w, att_w_], dim=1) for att_w, att_w_ in zip(att_ws, att_ws_)]
# check whether to finish generation
if int(sum(probs[-1] >= threshold)) > 0 or idx >= maxlen:
# check minimum length
if idx < minlen:
continue
outs = torch.cat(outs, dim=0).unsqueeze(0).transpose(1, 2) # (L, odim) -> (1, L, odim) -> (1, odim, L)
if self.postnet is not None:
outs = outs + self.postnet(outs) # (1, odim, L)
outs = outs.transpose(2, 1).squeeze(0) # (L, odim)
probs = torch.cat(probs, dim=0)
break
# concatenate attention weights -> (#layers, #heads, L, T)
att_ws = torch.stack(att_ws, dim=0)
return outs, probs, att_ws
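And a hypothetical synthesis call for a single utterance (the threshold and length-ratio values below are typical ESPnet decoding defaults, stated here as assumptions; model is the instance from the earlier sketches):
from argparse import Namespace
import torch

inference_args = Namespace(threshold=0.5, minlenratio=0.0, maxlenratio=10.0)
x = torch.randint(1, 50, (13,))           # one sequence of character ids

model.eval()
with torch.no_grad():
    outs, probs, att_ws = model.inference(x, inference_args)
print(outs.shape)    # (L, odim)
print(att_ws.shape)  # (#decoder layers, #heads, L, T)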
fastspeech
special
¶
duration_calculator
¶
Duration calculator related modules.
DurationCalculator (Module)
¶Duration calculator module for FastSpeech.
Todo
- Fix the duplicated calculation of diagonal head decision
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/fastspeech/duration_calculator.py
class DurationCalculator(torch.nn.Module):
"""Duration calculator module for FastSpeech.
Todo:
* Fix the duplicated calculation of diagonal head decision
"""
def __init__(self, teacher_model):
"""Initialize duration calculator module.
Args:
teacher_model (e2e_tts_transformer.Transformer): Pretrained auto-regressive Transformer.
"""
super(DurationCalculator, self).__init__()
if isinstance(teacher_model, Transformer):
self.register_buffer("diag_head_idx", torch.tensor(-1))
elif isinstance(teacher_model, Tacotron2):
pass
else:
raise ValueError("teacher model should be the instance of e2e_tts_transformer.Transformer "
"or e2e_tts_tacotron2.Tacotron2.")
self.teacher_model = teacher_model
def forward(self, xs, ilens, ys, olens, spembs=None):
"""Calculate forward propagation.
Args:
xs (Tensor): Batch of the padded sequences of character ids (B, Tmax).
ilens (Tensor): Batch of lengths of each input sequence (B,).
ys (Tensor): Batch of the padded sequence of target features (B, Lmax, odim).
olens (Tensor): Batch of lengths of each output sequence (B,).
spembs (Tensor, optional): Batch of speaker embedding vectors (B, spk_embed_dim).
Returns:
Tensor: Batch of durations (B, Tmax).
"""
if isinstance(self.teacher_model, Transformer):
att_ws = self._calculate_encoder_decoder_attentions(xs, ilens, ys, olens, spembs=spembs)
# TODO(kan-bayashi): fix this issue
# this does not work in multi-gpu case. registered buffer is not saved.
if int(self.diag_head_idx) == -1:
self._init_diagonal_head(att_ws)
att_ws = att_ws[:, self.diag_head_idx]
else:
# NOTE(kan-bayashi): Here we assume that the teacher is tacotron 2
att_ws = self.teacher_model.calculate_all_attentions(
xs, ilens, ys, spembs=spembs, keep_tensor=True)
durations = [self._calculate_duration(att_w, ilen, olen) for att_w, ilen, olen in zip(att_ws, ilens, olens)]
return pad_list(durations, 0)
@staticmethod
def _calculate_duration(att_w, ilen, olen):
return torch.stack([att_w[:olen, :ilen].argmax(-1).eq(i).sum() for i in range(ilen)])
def _init_diagonal_head(self, att_ws):
diagonal_scores = att_ws.max(dim=-1)[0].mean(dim=-1).mean(dim=0) # (H * L,)
self.register_buffer("diag_head_idx", diagonal_scores.argmax())
def _calculate_encoder_decoder_attentions(self, xs, ilens, ys, olens, spembs=None):
att_dict = self.teacher_model.calculate_all_attentions(
xs, ilens, ys, olens, spembs=spembs, skip_output=True, keep_tensor=True)
return torch.cat([att_dict[k] for k in att_dict.keys() if "src_attn" in k], dim=1) # (B, H*L, Lmax, Tmax)
__init__(self, teacher_model)
special
¶Initialize duration calculator module.
Parameters:
Name | Type | Description | Default
---|---|---|---
teacher_model | e2e_tts_transformer.Transformer | Pretrained auto-regressive Transformer. | required
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/fastspeech/duration_calculator.py
def __init__(self, teacher_model):
"""Initialize duration calculator module.
Args:
teacher_model (e2e_tts_transformer.Transformer): Pretrained auto-regressive Transformer.
"""
super(DurationCalculator, self).__init__()
if isinstance(teacher_model, Transformer):
self.register_buffer("diag_head_idx", torch.tensor(-1))
elif isinstance(teacher_model, Tacotron2):
pass
else:
raise ValueError("teacher model should be the instance of e2e_tts_transformer.Transformer "
"or e2e_tts_tacotron2.Tacotron2.")
self.teacher_model = teacher_model
forward(self, xs, ilens, ys, olens, spembs=None)
¶Calculate forward propagation.
Parameters:
Name | Type | Description | Default
---|---|---|---
xs | Tensor | Batch of the padded sequences of character ids (B, Tmax). | required
ilens | Tensor | Batch of lengths of each input sequence (B,). | required
ys | Tensor | Batch of the padded sequence of target features (B, Lmax, odim). | required
olens | Tensor | Batch of lengths of each output sequence (B,). | required
spembs | Tensor | Batch of speaker embedding vectors (B, spk_embed_dim). | None
Returns:
Type | Description
---|---
Tensor | Batch of durations (B, Tmax).
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/fastspeech/duration_calculator.py
def forward(self, xs, ilens, ys, olens, spembs=None):
"""Calculate forward propagation.
Args:
xs (Tensor): Batch of the padded sequences of character ids (B, Tmax).
ilens (Tensor): Batch of lengths of each input sequence (B,).
ys (Tensor): Batch of the padded sequence of target features (B, Lmax, odim).
olens (Tensor): Batch of lengths of each output sequence (B,).
spembs (Tensor, optional): Batch of speaker embedding vectors (B, spk_embed_dim).
Returns:
Tensor: Batch of durations (B, Tmax).
"""
if isinstance(self.teacher_model, Transformer):
att_ws = self._calculate_encoder_decoder_attentions(xs, ilens, ys, olens, spembs=spembs)
# TODO(kan-bayashi): fix this issue
# this does not work in multi-gpu case. registered buffer is not saved.
if int(self.diag_head_idx) == -1:
self._init_diagonal_head(att_ws)
att_ws = att_ws[:, self.diag_head_idx]
else:
# NOTE(kan-bayashi): Here we assume that the teacher is tacotron 2
att_ws = self.teacher_model.calculate_all_attentions(
xs, ilens, ys, spembs=spembs, keep_tensor=True)
durations = [self._calculate_duration(att_w, ilen, olen) for att_w, ilen, olen in zip(att_ws, ilens, olens)]
return pad_list(durations, 0)
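A sketch of how the calculator can be driven with a Transformer teacher (import path assumed from the source location above; the batch tensors and model follow the earlier sketches; in practice the teacher must be trained for the durations to be meaningful). For a reduction factor of 1, each row of the result sums to the corresponding output length:
import torch
from tools.espnet_minimal.nets.pytorch_backend.fastspeech.duration_calculator import DurationCalculator  # assumed path

duration_calculator = DurationCalculator(model)   # teacher: the auto-regressive Transformer
with torch.no_grad():
    durations = duration_calculator(xs, ilens, ys, olens)   # (B, Tmax), LongTensor
print(durations.sum(dim=1), olens)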
duration_predictor
¶
Duration predictor related modules.
DurationPredictor (Module)
¶Duration predictor module.
This is the duration predictor module described in FastSpeech: Fast, Robust and Controllable Text to Speech (https://arxiv.org/pdf/1905.09263.pdf).
The duration predictor predicts the duration of each frame in the log domain from the hidden embeddings of the encoder.
Note
The calculation domain of the outputs differs between `forward` and `inference`: in `forward` the outputs are calculated in the log domain, while in `inference` they are calculated in the linear domain.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/fastspeech/duration_predictor.py
class DurationPredictor(torch.nn.Module):
"""Duration predictor module.
This is a module of duration predictor described in `FastSpeech: Fast, Robust and Controllable Text to Speech`_.
The duration predictor predicts a duration of each frame in log domain from the hidden embeddings of encoder.
.. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
https://arxiv.org/pdf/1905.09263.pdf
Note:
The calculation domain of outputs is different between in `forward` and in `inference`. In `forward`,
the outputs are calculated in log domain but in `inference`, those are calculated in linear domain.
"""
def __init__(self, idim, n_layers=2, n_chans=384, kernel_size=3, dropout_rate=0.1, offset=1.0):
"""Initilize duration predictor module.
Args:
idim (int): Input dimension.
n_layers (int, optional): Number of convolutional layers.
n_chans (int, optional): Number of channels of convolutional layers.
kernel_size (int, optional): Kernel size of convolutional layers.
dropout_rate (float, optional): Dropout rate.
offset (float, optional): Offset value to avoid nan in log domain.
"""
super(DurationPredictor, self).__init__()
self.offset = offset
self.conv = torch.nn.ModuleList()
for idx in range(n_layers):
in_chans = idim if idx == 0 else n_chans
self.conv += [torch.nn.Sequential(
torch.nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=(kernel_size - 1) // 2),
torch.nn.ReLU(),
LayerNorm(n_chans, dim=1),
torch.nn.Dropout(dropout_rate)
)]
self.linear = torch.nn.Linear(n_chans, 1)
def _forward(self, xs, x_masks=None, is_inference=False):
xs = xs.transpose(1, -1) # (B, idim, Tmax)
for f in self.conv:
xs = f(xs) # (B, C, Tmax)
# NOTE: calculate in log domain
xs = self.linear(xs.transpose(1, -1)).squeeze(-1) # (B, Tmax)
if is_inference:
# NOTE: calculate in linear domain
xs = torch.clamp(torch.round(xs.exp() - self.offset), min=0).long() # avoid negative value
if x_masks is not None:
xs = xs.masked_fill(x_masks, 0.0)
return xs
def forward(self, xs, x_masks=None):
"""Calculate forward propagation.
Args:
xs (Tensor): Batch of input sequences (B, Tmax, idim).
x_masks (ByteTensor, optional): Batch of masks indicating padded part (B, Tmax).
Returns:
Tensor: Batch of predicted durations in log domain (B, Tmax).
"""
return self._forward(xs, x_masks, False)
def inference(self, xs, x_masks=None):
"""Inference duration.
Args:
xs (Tensor): Batch of input sequences (B, Tmax, idim).
x_masks (ByteTensor, optional): Batch of masks indicating padded part (B, Tmax).
Returns:
LongTensor: Batch of predicted durations in linear domain (B, Tmax).
"""
return self._forward(xs, x_masks, True)
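Putting the two entry points side by side: training goes through `forward` (log-domain outputs fed to `DurationPredictorLoss`), while synthesis goes through `inference` (rounded, non-negative integer frame counts). A minimal usage sketch with hypothetical shapes (idim=384, batch of 2, 10 tokens):

```python
import torch

predictor = DurationPredictor(idim=384)        # defaults: 2 conv layers, 384 channels, kernel size 3

hs = torch.randn(2, 10, 384)                   # (B, Tmax, idim) encoder hidden states
log_d = predictor(hs)                          # (B, Tmax) durations in the log domain (training)
int_d = predictor.inference(hs)                # (B, Tmax) LongTensor durations in the linear domain
```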
__init__(self, idim, n_layers=2, n_chans=384, kernel_size=3, dropout_rate=0.1, offset=1.0)
special
¶Initialize duration predictor module.
Parameters:
Name | Type | Description | Default
---|---|---|---
idim | int | Input dimension. | required
n_layers | int | Number of convolutional layers. | 2
n_chans | int | Number of channels of convolutional layers. | 384
kernel_size | int | Kernel size of convolutional layers. | 3
dropout_rate | float | Dropout rate. | 0.1
offset | float | Offset value to avoid nan in log domain. | 1.0
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/fastspeech/duration_predictor.py
def __init__(self, idim, n_layers=2, n_chans=384, kernel_size=3, dropout_rate=0.1, offset=1.0):
"""Initilize duration predictor module.
Args:
idim (int): Input dimension.
n_layers (int, optional): Number of convolutional layers.
n_chans (int, optional): Number of channels of convolutional layers.
kernel_size (int, optional): Kernel size of convolutional layers.
dropout_rate (float, optional): Dropout rate.
offset (float, optional): Offset value to avoid nan in log domain.
"""
super(DurationPredictor, self).__init__()
self.offset = offset
self.conv = torch.nn.ModuleList()
for idx in range(n_layers):
in_chans = idim if idx == 0 else n_chans
self.conv += [torch.nn.Sequential(
torch.nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=(kernel_size - 1) // 2),
torch.nn.ReLU(),
LayerNorm(n_chans, dim=1),
torch.nn.Dropout(dropout_rate)
)]
self.linear = torch.nn.Linear(n_chans, 1)
forward(self, xs, x_masks=None)
¶Calculate forward propagation.
Parameters:
Name | Type | Description | Default
---|---|---|---
xs | Tensor | Batch of input sequences (B, Tmax, idim). | required
x_masks | ByteTensor | Batch of masks indicating padded part (B, Tmax). | None
Returns:
Type | Description
---|---
Tensor | Batch of predicted durations in log domain (B, Tmax).
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/fastspeech/duration_predictor.py
def forward(self, xs, x_masks=None):
"""Calculate forward propagation.
Args:
xs (Tensor): Batch of input sequences (B, Tmax, idim).
x_masks (ByteTensor, optional): Batch of masks indicating padded part (B, Tmax).
Returns:
Tensor: Batch of predicted durations in log domain (B, Tmax).
"""
return self._forward(xs, x_masks, False)
inference(self, xs, x_masks=None)
¶Inference duration.
Parameters:
Name | Type | Description | Default
---|---|---|---
xs | Tensor | Batch of input sequences (B, Tmax, idim). | required
x_masks | ByteTensor | Batch of masks indicating padded part (B, Tmax). | None
Returns:
Type | Description
---|---
LongTensor | Batch of predicted durations in linear domain (B, Tmax).
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/fastspeech/duration_predictor.py
def inference(self, xs, x_masks=None):
"""Inference duration.
Args:
xs (Tensor): Batch of input sequences (B, Tmax, idim).
x_masks (ByteTensor, optional): Batch of masks indicating padded part (B, Tmax).
Returns:
LongTensor: Batch of predicted durations in linear domain (B, Tmax).
"""
return self._forward(xs, x_masks, True)
DurationPredictorLoss (Module)
¶Loss function module for duration predictor.
The loss value is calculated in the log domain to make it Gaussian.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/fastspeech/duration_predictor.py
class DurationPredictorLoss(torch.nn.Module):
"""Loss function module for duration predictor.
The loss value is Calculated in log domain to make it Gaussian.
"""
def __init__(self, offset=1.0, reduction="mean"):
"""Initilize duration predictor loss module.
Args:
offset (float, optional): Offset value to avoid nan in log domain.
reduction (str): Reduction type in loss calculation.
"""
super(DurationPredictorLoss, self).__init__()
self.criterion = torch.nn.MSELoss(reduction=reduction)
self.offset = offset
def forward(self, outputs, targets):
"""Calculate forward propagation.
Args:
outputs (Tensor): Batch of prediction durations in log domain (B, T)
targets (LongTensor): Batch of groundtruth durations in linear domain (B, T)
Returns:
Tensor: Mean squared error loss value.
Note:
`outputs` is in log domain but `targets` is in linear domain.
"""
# NOTE: outputs is in log domain while targets in linear
targets = torch.log(targets.float() + self.offset)
loss = self.criterion(outputs, targets)
return loss
__init__(self, offset=1.0, reduction='mean')
special
¶Initialize duration predictor loss module.
Parameters:
Name | Type | Description | Default
---|---|---|---
offset | float | Offset value to avoid nan in log domain. | 1.0
reduction | str | Reduction type in loss calculation. | 'mean'
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/fastspeech/duration_predictor.py
def __init__(self, offset=1.0, reduction="mean"):
"""Initilize duration predictor loss module.
Args:
offset (float, optional): Offset value to avoid nan in log domain.
reduction (str): Reduction type in loss calculation.
"""
super(DurationPredictorLoss, self).__init__()
self.criterion = torch.nn.MSELoss(reduction=reduction)
self.offset = offset
forward(self, outputs, targets)
¶Calculate forward propagation.
Parameters:
Name | Type | Description | Default
---|---|---|---
outputs | Tensor | Batch of prediction durations in log domain (B, T). | required
targets | LongTensor | Batch of groundtruth durations in linear domain (B, T). | required
Returns:
Type | Description
---|---
Tensor | Mean squared error loss value.
Note
`outputs` is in the log domain but `targets` is in the linear domain.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/fastspeech/duration_predictor.py
def forward(self, outputs, targets):
"""Calculate forward propagation.
Args:
outputs (Tensor): Batch of prediction durations in log domain (B, T)
targets (LongTensor): Batch of groundtruth durations in linear domain (B, T)
Returns:
Tensor: Mean squared error loss value.
Note:
`outputs` is in log domain but `targets` is in linear domain.
"""
# NOTE: outputs is in log domain while targets in linear
targets = torch.log(targets.float() + self.offset)
loss = self.criterion(outputs, targets)
return loss
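A minimal training-step sketch pairing the predictor with this loss; the groundtruth durations here are hypothetical integer frame counts such as those produced by `DurationCalculator`:

```python
import torch

predictor = DurationPredictor(idim=384)
criterion = DurationPredictorLoss()            # offset=1.0, reduction="mean"

hs = torch.randn(2, 10, 384)                   # (B, Tmax, idim) encoder outputs
ds = torch.randint(1, 5, (2, 10))              # (B, Tmax) groundtruth durations, linear domain

loss = criterion(predictor(hs), ds)            # targets are moved to the log domain inside the loss
loss.backward()
```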
length_regulator
¶
Length regulator related modules.
LengthRegulator (Module)
¶Length regulator module for feed-forward Transformer.
This is the length regulator module described in FastSpeech: Fast, Robust and Controllable Text to Speech (https://arxiv.org/pdf/1905.09263.pdf).
The length regulator expands char- or phoneme-level embedding features to frame level by repeating each feature according to its corresponding predicted duration.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/fastspeech/length_regulator.py
class LengthRegulator(torch.nn.Module):
"""Length regulator module for feed-forward Transformer.
This is a module of length regulator described in `FastSpeech: Fast, Robust and Controllable Text to Speech`_.
The length regulator expands char or phoneme-level embedding features to frame-level by repeating each
feature based on the corresponding predicted durations.
.. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
https://arxiv.org/pdf/1905.09263.pdf
"""
def __init__(self, pad_value=0.0):
"""Initilize length regulator module.
Args:
pad_value (float, optional): Value used for padding.
"""
super(LengthRegulator, self).__init__()
self.pad_value = pad_value
def forward(self, xs, ds, ilens, alpha=1.0):
"""Calculate forward propagation.
Args:
xs (Tensor): Batch of sequences of char or phoneme embeddings (B, Tmax, D).
ds (LongTensor): Batch of durations of each frame (B, T).
ilens (LongTensor): Batch of input lengths (B,).
alpha (float, optional): Alpha value to control speed of speech.
Returns:
Tensor: replicated input tensor based on durations (B, T*, D).
"""
assert alpha > 0
if alpha != 1.0:
ds = torch.round(ds.float() * alpha).long()
xs = [x[:ilen] for x, ilen in zip(xs, ilens)]
ds = [d[:ilen] for d, ilen in zip(ds, ilens)]
xs = [self._repeat_one_sequence(x, d) for x, d in zip(xs, ds)]
return pad_list(xs, self.pad_value)
def _repeat_one_sequence(self, x, d):
"""Repeat each frame according to duration.
Examples:
>>> x = torch.tensor([[1], [2], [3]])
tensor([[1],
[2],
[3]])
>>> d = torch.tensor([1, 2, 3])
tensor([1, 2, 3])
>>> self._repeat_one_sequence(x, d)
tensor([[1],
[2],
[2],
[3],
[3],
[3]])
"""
if d.sum() == 0:
logging.warn("all of the predicted durations are 0. fill 0 with 1.")
d = d.fill_(1)
return torch.cat([x_.repeat(int(d_), 1) for x_, d_ in zip(x, d) if d_ != 0], dim=0)
__init__(self, pad_value=0.0)
special
¶Initialize length regulator module.
Parameters:
Name | Type | Description | Default
---|---|---|---
pad_value | float | Value used for padding. | 0.0
forward(self, xs, ds, ilens, alpha=1.0)
¶Calculate forward propagation.
Parameters:
Name | Type | Description | Default
---|---|---|---
xs | Tensor | Batch of sequences of char or phoneme embeddings (B, Tmax, D). | required
ds | LongTensor | Batch of durations of each frame (B, T). | required
ilens | LongTensor | Batch of input lengths (B,). | required
alpha | float | Alpha value to control speed of speech. | 1.0
Returns:
Type | Description
---|---
Tensor | Replicated input tensor based on durations (B, T*, D).
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/fastspeech/length_regulator.py
def forward(self, xs, ds, ilens, alpha=1.0):
"""Calculate forward propagation.
Args:
xs (Tensor): Batch of sequences of char or phoneme embeddings (B, Tmax, D).
ds (LongTensor): Batch of durations of each frame (B, T).
ilens (LongTensor): Batch of input lengths (B,).
alpha (float, optional): Alpha value to control speed of speech.
Returns:
Tensor: replicated input tensor based on durations (B, T*, D).
"""
assert alpha > 0
if alpha != 1.0:
ds = torch.round(ds.float() * alpha).long()
xs = [x[:ilen] for x, ilen in zip(xs, ilens)]
ds = [d[:ilen] for d, ilen in zip(ds, ilens)]
xs = [self._repeat_one_sequence(x, d) for x, d in zip(xs, ds)]
return pad_list(xs, self.pad_value)
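A minimal usage sketch with hypothetical shapes; `alpha` rescales every duration, so values below 1.0 produce fewer frames (faster speech) and values above 1.0 more frames (slower speech):

```python
import torch

regulator = LengthRegulator()
xs = torch.randn(2, 5, 8)                          # (B, Tmax, D) phoneme-level embeddings
ds = torch.tensor([[1, 2, 1, 3, 1],
                   [2, 2, 1, 0, 0]])               # (B, Tmax) integer durations
ilens = torch.tensor([5, 3])                       # true input lengths

frames = regulator(xs, ds, ilens)                  # (B, T*, D) with T* = max total duration (here 8)
faster = regulator(xs, ds, ilens, alpha=0.5)       # roughly half the frames per phoneme
```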
initialization
¶
Initialization functions for RNN sequence-to-sequence models.
lecun_normal_init_parameters(module)
¶
Initialize parameters in LeCun's manner.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/initialization.py
def lecun_normal_init_parameters(module):
"""Initialize parameters in the LeCun's manner."""
for p in module.parameters():
data = p.data
if data.dim() == 1:
# bias
data.zero_()
elif data.dim() == 2:
# linear weight
n = data.size(1)
stdv = 1. / math.sqrt(n)
data.normal_(0, stdv)
elif data.dim() in (3, 4):
# conv weight
n = data.size(1)
for k in data.size()[2:]:
n *= k
stdv = 1. / math.sqrt(n)
data.normal_(0, stdv)
else:
raise NotImplementedError
set_forget_bias_to_one(bias)
¶
Initialize a bias vector in the forget gate with one.
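The source of set_forget_bias_to_one is not reproduced here. The sketch below shows the usual idea under the assumption of PyTorch's packed LSTM bias layout, where the four gate biases are stored as (input, forget, cell, output) and the forget gate therefore occupies the second quarter of the vector:

```python
def set_forget_bias_to_one(bias):
    """Fill the forget-gate slice of a packed LSTM bias vector with ones (sketch)."""
    n = bias.size(0)
    start, end = n // 4, n // 2      # second quarter = forget gate in PyTorch's gate ordering
    bias.data[start:end].fill_(1.0)
```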
uniform_init_parameters(module)
¶
Initialize parameters with a uniform distribution.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/initialization.py
def uniform_init_parameters(module):
"""Initialize parameters with an uniform distribution."""
for p in module.parameters():
data = p.data
if data.dim() == 1:
# bias
data.uniform_(-0.1, 0.1)
elif data.dim() == 2:
# linear weight
data.uniform_(-0.1, 0.1)
elif data.dim() in (3, 4):
# conv weight
pass # use the pytorch default
else:
raise NotImplementedError
nets_utils
¶
Network related utility tools.
make_non_pad_mask(lengths, xs=None, length_dim=-1)
¶
Make mask tensor containing indices of non-padded part.
Parameters:
Name | Type | Description | Default
---|---|---|---
lengths | LongTensor or List | Batch of lengths (B,). | required
xs | Tensor | The reference tensor. If set, masks will be the same shape as this tensor. | None
length_dim | int | Dimension indicator of the above tensor. See the example. | -1
Returns:
Type | Description
---|---
ByteTensor | Mask tensor containing indices of the non-padded part. dtype=torch.uint8 in PyTorch < 1.2, dtype=torch.bool in PyTorch >= 1.2.
Examples:
With only lengths.
>>> lengths = [5, 3, 2]
>>> make_non_pad_mask(lengths)
masks = [[1, 1, 1, 1 ,1],
[1, 1, 1, 0, 0],
[1, 1, 0, 0, 0]]
With the reference tensor.
>>> xs = torch.zeros((3, 2, 4))
>>> make_non_pad_mask(lengths, xs)
tensor([[[1, 1, 1, 1],
[1, 1, 1, 1]],
[[1, 1, 1, 0],
[1, 1, 1, 0]],
[[1, 1, 0, 0],
[1, 1, 0, 0]]], dtype=torch.uint8)
>>> xs = torch.zeros((3, 2, 6))
>>> make_non_pad_mask(lengths, xs)
tensor([[[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0]],
[[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0]],
[[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
With the reference tensor and dimension indicator.
>>> xs = torch.zeros((3, 6, 6))
>>> make_non_pad_mask(lengths, xs, 1)
tensor([[[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0]],
[[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0]],
[[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
>>> make_non_pad_mask(lengths, xs, 2)
tensor([[[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0]],
[[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0]],
[[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/nets_utils.py
def make_non_pad_mask(lengths, xs=None, length_dim=-1):
"""Make mask tensor containing indices of non-padded part.
Args:
lengths (LongTensor or List): Batch of lengths (B,).
xs (Tensor, optional): The reference tensor. If set, masks will be the same shape as this tensor.
length_dim (int, optional): Dimension indicator of the above tensor. See the example.
Returns:
ByteTensor: mask tensor containing indices of padded part.
dtype=torch.uint8 in PyTorch 1.2-
dtype=torch.bool in PyTorch 1.2+ (including 1.2)
Examples:
With only lengths.
>>> lengths = [5, 3, 2]
>>> make_non_pad_mask(lengths)
masks = [[1, 1, 1, 1 ,1],
[1, 1, 1, 0, 0],
[1, 1, 0, 0, 0]]
With the reference tensor.
>>> xs = torch.zeros((3, 2, 4))
>>> make_non_pad_mask(lengths, xs)
tensor([[[1, 1, 1, 1],
[1, 1, 1, 1]],
[[1, 1, 1, 0],
[1, 1, 1, 0]],
[[1, 1, 0, 0],
[1, 1, 0, 0]]], dtype=torch.uint8)
>>> xs = torch.zeros((3, 2, 6))
>>> make_non_pad_mask(lengths, xs)
tensor([[[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0]],
[[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0]],
[[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
With the reference tensor and dimension indicator.
>>> xs = torch.zeros((3, 6, 6))
>>> make_non_pad_mask(lengths, xs, 1)
tensor([[[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0]],
[[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0]],
[[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
>>> make_non_pad_mask(lengths, xs, 2)
tensor([[[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0]],
[[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0],
[1, 1, 1, 0, 0, 0]],
[[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
"""
return ~make_pad_mask(lengths, xs, length_dim)
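A typical use of these masks is to restrict a loss or statistic to the non-padded frames of a batch. A minimal sketch with hypothetical tensors:

```python
import torch

lengths = [5, 3, 2]
scores = torch.randn(3, 5)                 # (B, Tmax) some per-frame quantity
mask = make_non_pad_mask(lengths)          # True/1 on real frames, False/0 on padding

valid = scores.masked_select(mask)         # flattened values of the real frames only
mean_over_valid = valid.mean()
```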
make_pad_mask(lengths, xs=None, length_dim=-1)
¶
Make mask tensor containing indices of padded part.
Parameters:
Name | Type | Description | Default
---|---|---|---
lengths | LongTensor or List | Batch of lengths (B,). | required
xs | Tensor | The reference tensor. If set, masks will be the same shape as this tensor. | None
length_dim | int | Dimension indicator of the above tensor. See the example. | -1
Returns:
Type | Description
---|---
Tensor | Mask tensor containing indices of the padded part. dtype=torch.uint8 in PyTorch < 1.2, dtype=torch.bool in PyTorch >= 1.2.
Examples:
With only lengths.
>>> lengths = [5, 3, 2]
>>> make_pad_mask(lengths)
masks = [[0, 0, 0, 0 ,0],
[0, 0, 0, 1, 1],
[0, 0, 1, 1, 1]]
With the reference tensor.
>>> xs = torch.zeros((3, 2, 4))
>>> make_pad_mask(lengths, xs)
tensor([[[0, 0, 0, 0],
[0, 0, 0, 0]],
[[0, 0, 0, 1],
[0, 0, 0, 1]],
[[0, 0, 1, 1],
[0, 0, 1, 1]]], dtype=torch.uint8)
>>> xs = torch.zeros((3, 2, 6))
>>> make_pad_mask(lengths, xs)
tensor([[[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1]],
[[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1]],
[[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
With the reference tensor and dimension indicator.
>>> xs = torch.zeros((3, 6, 6))
>>> make_pad_mask(lengths, xs, 1)
tensor([[[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1]],
[[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1]],
[[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
>>> make_pad_mask(lengths, xs, 2)
tensor([[[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1]],
[[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1]],
[[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/nets_utils.py
def make_pad_mask(lengths, xs=None, length_dim=-1):
"""Make mask tensor containing indices of padded part.
Args:
lengths (LongTensor or List): Batch of lengths (B,).
xs (Tensor, optional): The reference tensor. If set, masks will be the same shape as this tensor.
length_dim (int, optional): Dimension indicator of the above tensor. See the example.
Returns:
Tensor: Mask tensor containing indices of padded part.
dtype=torch.uint8 in PyTorch 1.2-
dtype=torch.bool in PyTorch 1.2+ (including 1.2)
Examples:
With only lengths.
>>> lengths = [5, 3, 2]
>>> make_non_pad_mask(lengths)
masks = [[0, 0, 0, 0 ,0],
[0, 0, 0, 1, 1],
[0, 0, 1, 1, 1]]
With the reference tensor.
>>> xs = torch.zeros((3, 2, 4))
>>> make_pad_mask(lengths, xs)
tensor([[[0, 0, 0, 0],
[0, 0, 0, 0]],
[[0, 0, 0, 1],
[0, 0, 0, 1]],
[[0, 0, 1, 1],
[0, 0, 1, 1]]], dtype=torch.uint8)
>>> xs = torch.zeros((3, 2, 6))
>>> make_pad_mask(lengths, xs)
tensor([[[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1]],
[[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1]],
[[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
With the reference tensor and dimension indicator.
>>> xs = torch.zeros((3, 6, 6))
>>> make_pad_mask(lengths, xs, 1)
tensor([[[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1]],
[[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1]],
[[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
>>> make_pad_mask(lengths, xs, 2)
tensor([[[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 0, 1]],
[[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 1, 1, 1]],
[[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
"""
if length_dim == 0:
raise ValueError('length_dim cannot be 0: {}'.format(length_dim))
if not isinstance(lengths, list):
lengths = lengths.tolist()
bs = int(len(lengths))
if xs is None:
maxlen = int(max(lengths))
else:
maxlen = xs.size(length_dim)
seq_range = torch.arange(0, maxlen, dtype=torch.int64)
seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
mask = seq_range_expand >= seq_length_expand
if xs is not None:
assert xs.size(0) == bs, (xs.size(0), bs)
if length_dim < 0:
length_dim = xs.dim() + length_dim
# ind = (:, None, ..., None, :, , None, ..., None)
ind = tuple(slice(None) if i in (0, length_dim) else None
for i in range(xs.dim()))
mask = mask[ind].expand_as(xs).to(xs.device)
return mask
mask_by_length(xs, lengths, fill=0)
¶
Mask tensor according to length.
Parameters:
Name | Type | Description | Default
---|---|---|---
xs | Tensor | Batch of input tensor (B, `*`). | required
lengths | LongTensor or List | Batch of lengths (B,). | required
fill | int or float | Value to fill masked part. | 0
Returns:
Type | Description
---|---
Tensor | Batch of masked input tensor (B, `*`).
Examples:
>>> x = torch.arange(5).repeat(3, 1) + 1
>>> x
tensor([[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5]])
>>> lengths = [5, 3, 2]
>>> mask_by_length(x, lengths)
tensor([[1, 2, 3, 4, 5],
[1, 2, 3, 0, 0],
[1, 2, 0, 0, 0]])
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/nets_utils.py
def mask_by_length(xs, lengths, fill=0):
"""Mask tensor according to length.
Args:
xs (Tensor): Batch of input tensor (B, `*`).
lengths (LongTensor or List): Batch of lengths (B,).
fill (int or float): Value to fill masked part.
Returns:
Tensor: Batch of masked input tensor (B, `*`).
Examples:
>>> x = torch.arange(5).repeat(3, 1) + 1
>>> x
tensor([[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5]])
>>> lengths = [5, 3, 2]
>>> mask_by_length(x, lengths)
tensor([[1, 2, 3, 4, 5],
[1, 2, 3, 0, 0],
[1, 2, 0, 0, 0]])
"""
assert xs.size(0) == len(lengths)
ret = xs.data.new(*xs.size()).fill_(fill)
for i, l in enumerate(lengths):
ret[i, :l] = xs[i, :l]
return ret
pad_list(xs, pad_value)
¶
Perform padding for the list of tensors.
Parameters:
Name | Type | Description | Default
---|---|---|---
xs | List | List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)]. | required
pad_value | float | Value for padding. | required
Returns:
Type | Description
---|---
Tensor | Padded tensor (B, Tmax, `*`).
Examples:
>>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
>>> x
[tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
>>> pad_list(x, 0)
tensor([[1., 1., 1., 1.],
[1., 1., 0., 0.],
[1., 0., 0., 0.]])
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/nets_utils.py
def pad_list(xs, pad_value):
"""Perform padding for the list of tensors.
Args:
xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
pad_value (float): Value for padding.
Returns:
Tensor: Padded tensor (B, Tmax, `*`).
Examples:
>>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
>>> x
[tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
>>> pad_list(x, 0)
tensor([[1., 1., 1., 1.],
[1., 1., 0., 0.],
[1., 0., 0., 0.]])
"""
n_batch = len(xs)
max_len = max(x.size(0) for x in xs)
pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
for i in range(n_batch):
pad[i, :xs[i].size(0)] = xs[i]
return pad
th_accuracy(pad_outputs, pad_targets, ignore_label)
¶
Calculate accuracy.
Parameters:
Name | Type | Description | Default
---|---|---|---
pad_outputs | Tensor | Prediction tensors (B * Lmax, D). | required
pad_targets | LongTensor | Target label tensors (B, Lmax). | required
ignore_label | int | Ignore label id. | required
Returns:
Type | Description
---|---
float | Accuracy value (0.0 - 1.0).
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/nets_utils.py
def th_accuracy(pad_outputs, pad_targets, ignore_label):
"""Calculate accuracy.
Args:
pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
pad_targets (LongTensor): Target label tensors (B, Lmax, D).
ignore_label (int): Ignore label id.
Returns:
float: Accuracy value (0.0 - 1.0).
"""
pad_pred = pad_outputs.view(
pad_targets.size(0),
pad_targets.size(1),
pad_outputs.size(1)).argmax(2)
mask = pad_targets != ignore_label
numerator = torch.sum(pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
denominator = torch.sum(mask)
return float(numerator) / float(denominator)
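Here `pad_outputs` is the flattened prediction (B * Lmax, D), while the code compares it against a (B, Lmax) tensor of label ids whose padded positions are set to `ignore_label`. A minimal usage sketch with hypothetical sizes:

```python
import torch

B, Lmax, D = 2, 4, 5
pad_outputs = torch.randn(B * Lmax, D)           # flattened logits
pad_targets = torch.randint(0, D, (B, Lmax))     # label ids
pad_targets[1, 2:] = -1                          # pretend the tail of the second utterance is padding

acc = th_accuracy(pad_outputs, pad_targets, ignore_label=-1)   # fraction of correct non-ignored labels
```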
to_device(m, x)
¶
Send tensor into the device of the module.
Parameters:
Name | Type | Description | Default
---|---|---|---
m | torch.nn.Module | Torch module. | required
x | Tensor | Torch tensor. | required
Returns:
Type | Description
---|---
Tensor | Torch tensor located in the same place as the torch module.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/nets_utils.py
def to_device(m, x):
"""Send tensor into the device of the module.
Args:
m (torch.nn.Module): Torch module.
x (Tensor): Torch tensor.
Returns:
Tensor: Torch tensor located in the same place as torch module.
"""
assert isinstance(m, torch.nn.Module)
device = next(m.parameters()).device
return x.to(device)
to_torch_tensor(x)
¶
Change to torch.Tensor or ComplexTensor from numpy.ndarray.
Parameters:
Name | Type | Description | Default
---|---|---|---
x | | Inputs. It should be one of numpy.ndarray, Tensor, ComplexTensor, and dict. | required
Returns:
Type | Description
---|---
Tensor or ComplexTensor | Type converted inputs.
Examples:
>>> xs = np.ones(3, dtype=np.float32)
>>> xs = to_torch_tensor(xs)
tensor([1., 1., 1.])
>>> xs = torch.ones(3, 4, 5)
>>> assert to_torch_tensor(xs) is xs
>>> xs = {'real': xs, 'imag': xs}
>>> to_torch_tensor(xs)
ComplexTensor(
Real:
tensor([1., 1., 1.])
Imag;
tensor([1., 1., 1.])
)
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/nets_utils.py
def to_torch_tensor(x):
"""Change to torch.Tensor or ComplexTensor from numpy.ndarray.
Args:
x: Inputs. It should be one of numpy.ndarray, Tensor, ComplexTensor, and dict.
Returns:
Tensor or ComplexTensor: Type converted inputs.
Examples:
>>> xs = np.ones(3, dtype=np.float32)
>>> xs = to_torch_tensor(xs)
tensor([1., 1., 1.])
>>> xs = torch.ones(3, 4, 5)
>>> assert to_torch_tensor(xs) is xs
>>> xs = {'real': xs, 'imag': xs}
>>> to_torch_tensor(xs)
ComplexTensor(
Real:
tensor([1., 1., 1.])
Imag;
tensor([1., 1., 1.])
)
"""
# If numpy, change to torch tensor
if isinstance(x, np.ndarray):
if x.dtype.kind == 'c':
# Dynamically importing because torch_complex requires python3
from torch_complex.tensor import ComplexTensor
return ComplexTensor(x)
else:
return torch.from_numpy(x)
# If {'real': ..., 'imag': ...}, convert to ComplexTensor
elif isinstance(x, dict):
# Dynamically importing because torch_complex requires python3
from torch_complex.tensor import ComplexTensor
if 'real' not in x or 'imag' not in x:
raise ValueError("has 'real' and 'imag' keys: {}".format(list(x)))
# Relative importing because of using python3 syntax
return ComplexTensor(x['real'], x['imag'])
# If torch.Tensor, as it is
elif isinstance(x, torch.Tensor):
return x
else:
error = ("x must be numpy.ndarray, torch.Tensor or a dict like "
"{{'real': torch.Tensor, 'imag': torch.Tensor}}, "
"but got {}".format(type(x)))
try:
from torch_complex.tensor import ComplexTensor
except Exception:
# If PY2
raise ValueError(error)
else:
# If PY3
if isinstance(x, ComplexTensor):
return x
else:
raise ValueError(error)
rnn
special
¶
attentions
¶
Attention modules for RNN.
AttAdd (Module)
¶Additive attention
:param int eprojs: # projection-units of encoder
:param int dunits: # units of decoder
:param int att_dim: attention dimension
:param bool han_mode: flag to switch on hierarchical attention mode and not store pre_compute_enc_h
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/attentions.py
class AttAdd(torch.nn.Module):
"""Additive attention
:param int eprojs: # projection-units of encoder
:param int dunits: # units of decoder
:param int att_dim: attention dimension
:param bool han_mode: flag to swith on mode of hierarchical attention and not store pre_compute_enc_h
"""
def __init__(self, eprojs, dunits, att_dim, han_mode=False):
super(AttAdd, self).__init__()
self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
self.gvec = torch.nn.Linear(att_dim, 1)
self.dunits = dunits
self.eprojs = eprojs
self.att_dim = att_dim
self.h_length = None
self.enc_h = None
self.pre_compute_enc_h = None
self.mask = None
self.han_mode = han_mode
def reset(self):
"""reset states"""
self.h_length = None
self.enc_h = None
self.pre_compute_enc_h = None
self.mask = None
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0):
"""AttAdd forward
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
:param list enc_hs_len: padded encoder hidden state length (B)
:param torch.Tensor dec_z: decoder hidden state (B x D_dec)
:param torch.Tensor att_prev: dummy (does not use)
:param float scaling: scaling parameter before applying softmax
:return: attention weighted encoder state (B, D_enc)
:rtype: torch.Tensor
:return: previous attention weights (B x T_max)
:rtype: torch.Tensor
"""
batch = len(enc_hs_pad)
# pre-compute all h outside the decoder loop
if self.pre_compute_enc_h is None or self.han_mode:
self.enc_h = enc_hs_pad # utt x frame x hdim
self.h_length = self.enc_h.size(1)
# utt x frame x att_dim
self.pre_compute_enc_h = self.mlp_enc(self.enc_h)
if dec_z is None:
dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
else:
dec_z = dec_z.view(batch, self.dunits)
# dec_z_tiled: utt x frame x att_dim
dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)
# dot with gvec
# utt x frame x att_dim -> utt x frame
e = self.gvec(torch.tanh(self.pre_compute_enc_h + dec_z_tiled)).squeeze(2)
# NOTE consider zero padding when compute w.
if self.mask is None:
self.mask = to_device(self, make_pad_mask(enc_hs_len))
e.masked_fill_(self.mask, -float('inf'))
w = F.softmax(scaling * e, dim=1)
# weighted sum over flames
# utt x hdim
# NOTE use bmm instead of sum(*)
c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)
return c, w
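A minimal decoding-loop sketch with hypothetical sizes; `reset` must be called once per new batch/utterance so that the cached encoder projection and padding mask are recomputed:

```python
import torch

att = AttAdd(eprojs=320, dunits=300, att_dim=320)
enc_hs_pad = torch.randn(2, 50, 320)          # (B, T_max, eprojs) encoder states
enc_hs_len = [50, 35]                         # true encoder lengths

att.reset()                                   # clear cached projection and mask for the new batch
dec_z, att_prev = None, None
for _ in range(3):                            # a few decoder steps
    c, att_prev = att(enc_hs_pad, enc_hs_len, dec_z, att_prev)   # c: (B, eprojs) context vector
    dec_z = torch.randn(2, 300)               # stand-in for the real decoder state update
```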
__init__(self, eprojs, dunits, att_dim, han_mode=False)
special
¶Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/attentions.py
def __init__(self, eprojs, dunits, att_dim, han_mode=False):
super(AttAdd, self).__init__()
self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
self.gvec = torch.nn.Linear(att_dim, 1)
self.dunits = dunits
self.eprojs = eprojs
self.att_dim = att_dim
self.h_length = None
self.enc_h = None
self.pre_compute_enc_h = None
self.mask = None
self.han_mode = han_mode
forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0)
¶AttAdd forward
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
:param list enc_hs_len: padded encoder hidden state length (B)
:param torch.Tensor dec_z: decoder hidden state (B x D_dec)
:param torch.Tensor att_prev: dummy (not used)
:param float scaling: scaling parameter before applying softmax
:return: attention weighted encoder state (B, D_enc)
:rtype: torch.Tensor
:return: previous attention weights (B x T_max)
:rtype: torch.Tensor
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/attentions.py
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0):
"""AttAdd forward
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
:param list enc_hs_len: padded encoder hidden state length (B)
:param torch.Tensor dec_z: decoder hidden state (B x D_dec)
:param torch.Tensor att_prev: dummy (does not use)
:param float scaling: scaling parameter before applying softmax
:return: attention weighted encoder state (B, D_enc)
:rtype: torch.Tensor
:return: previous attention weights (B x T_max)
:rtype: torch.Tensor
"""
batch = len(enc_hs_pad)
# pre-compute all h outside the decoder loop
if self.pre_compute_enc_h is None or self.han_mode:
self.enc_h = enc_hs_pad # utt x frame x hdim
self.h_length = self.enc_h.size(1)
# utt x frame x att_dim
self.pre_compute_enc_h = self.mlp_enc(self.enc_h)
if dec_z is None:
dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
else:
dec_z = dec_z.view(batch, self.dunits)
# dec_z_tiled: utt x frame x att_dim
dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)
# dot with gvec
# utt x frame x att_dim -> utt x frame
e = self.gvec(torch.tanh(self.pre_compute_enc_h + dec_z_tiled)).squeeze(2)
# NOTE consider zero padding when compute w.
if self.mask is None:
self.mask = to_device(self, make_pad_mask(enc_hs_len))
e.masked_fill_(self.mask, -float('inf'))
w = F.softmax(scaling * e, dim=1)
# weighted sum over flames
# utt x hdim
# NOTE use bmm instead of sum(*)
c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)
return c, w
reset(self)
¶
AttCov (Module)
¶Coverage mechanism attention
Reference: Get To The Point: Summarization with Pointer-Generator Network (https://arxiv.org/abs/1704.04368)
:param int eprojs: # projection-units of encoder
:param int dunits: # units of decoder
:param int att_dim: attention dimension
:param bool han_mode: flag to switch on hierarchical attention mode and not store pre_compute_enc_h
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/attentions.py
class AttCov(torch.nn.Module):
"""Coverage mechanism attention
Reference: Get To The Point: Summarization with Pointer-Generator Network
(https://arxiv.org/abs/1704.04368)
:param int eprojs: # projection-units of encoder
:param int dunits: # units of decoder
:param int att_dim: attention dimension
:param bool han_mode: flag to swith on mode of hierarchical attention and not store pre_compute_enc_h
"""
def __init__(self, eprojs, dunits, att_dim, han_mode=False):
super(AttCov, self).__init__()
self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
self.wvec = torch.nn.Linear(1, att_dim)
self.gvec = torch.nn.Linear(att_dim, 1)
self.dunits = dunits
self.eprojs = eprojs
self.att_dim = att_dim
self.h_length = None
self.enc_h = None
self.pre_compute_enc_h = None
self.mask = None
self.han_mode = han_mode
def reset(self):
"""reset states"""
self.h_length = None
self.enc_h = None
self.pre_compute_enc_h = None
self.mask = None
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev_list, scaling=2.0):
"""AttCov forward
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
:param list enc_hs_len: padded encoder hidden state length (B)
:param torch.Tensor dec_z: decoder hidden state (B x D_dec)
:param list att_prev_list: list of previous attention weight
:param float scaling: scaling parameter before applying softmax
:return: attention weighted encoder state (B, D_enc)
:rtype: torch.Tensor
:return: list of previous attention weights
:rtype: list
"""
batch = len(enc_hs_pad)
# pre-compute all h outside the decoder loop
if self.pre_compute_enc_h is None or self.han_mode:
self.enc_h = enc_hs_pad # utt x frame x hdim
self.h_length = self.enc_h.size(1)
# utt x frame x att_dim
self.pre_compute_enc_h = self.mlp_enc(self.enc_h)
if dec_z is None:
dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
else:
dec_z = dec_z.view(batch, self.dunits)
# initialize attention weight with uniform dist.
if att_prev_list is None:
# if no bias, 0 0-pad goes 0
att_prev_list = to_device(self, (1. - make_pad_mask(enc_hs_len).float()))
att_prev_list = [att_prev_list / att_prev_list.new(enc_hs_len).unsqueeze(-1)]
# att_prev_list: L' * [B x T] => cov_vec B x T
cov_vec = sum(att_prev_list)
# cov_vec: B x T => B x T x 1 => B x T x att_dim
cov_vec = self.wvec(cov_vec.unsqueeze(-1))
# dec_z_tiled: utt x frame x att_dim
dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)
# dot with gvec
# utt x frame x att_dim -> utt x frame
e = self.gvec(torch.tanh(cov_vec + self.pre_compute_enc_h + dec_z_tiled)).squeeze(2)
# NOTE consider zero padding when compute w.
if self.mask is None:
self.mask = to_device(self, make_pad_mask(enc_hs_len))
e.masked_fill_(self.mask, -float('inf'))
w = F.softmax(scaling * e, dim=1)
att_prev_list += [w]
# weighted sum over flames
# utt x hdim
# NOTE use bmm instead of sum(*)
c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)
return c, att_prev_list
__init__(self, eprojs, dunits, att_dim, han_mode=False)
special
¶Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/attentions.py
def __init__(self, eprojs, dunits, att_dim, han_mode=False):
super(AttCov, self).__init__()
self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
self.wvec = torch.nn.Linear(1, att_dim)
self.gvec = torch.nn.Linear(att_dim, 1)
self.dunits = dunits
self.eprojs = eprojs
self.att_dim = att_dim
self.h_length = None
self.enc_h = None
self.pre_compute_enc_h = None
self.mask = None
self.han_mode = han_mode
forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev_list, scaling=2.0)
¶AttCov forward
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
:param list enc_hs_len: padded encoder hidden state length (B)
:param torch.Tensor dec_z: decoder hidden state (B x D_dec)
:param list att_prev_list: list of previous attention weights
:param float scaling: scaling parameter before applying softmax
:return: attention weighted encoder state (B, D_enc)
:rtype: torch.Tensor
:return: list of previous attention weights
:rtype: list
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/attentions.py
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev_list, scaling=2.0):
"""AttCov forward
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
:param list enc_hs_len: padded encoder hidden state length (B)
:param torch.Tensor dec_z: decoder hidden state (B x D_dec)
:param list att_prev_list: list of previous attention weight
:param float scaling: scaling parameter before applying softmax
:return: attention weighted encoder state (B, D_enc)
:rtype: torch.Tensor
:return: list of previous attention weights
:rtype: list
"""
batch = len(enc_hs_pad)
# pre-compute all h outside the decoder loop
if self.pre_compute_enc_h is None or self.han_mode:
self.enc_h = enc_hs_pad # utt x frame x hdim
self.h_length = self.enc_h.size(1)
# utt x frame x att_dim
self.pre_compute_enc_h = self.mlp_enc(self.enc_h)
if dec_z is None:
dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
else:
dec_z = dec_z.view(batch, self.dunits)
# initialize attention weight with uniform dist.
if att_prev_list is None:
# if no bias, 0 0-pad goes 0
att_prev_list = to_device(self, (1. - make_pad_mask(enc_hs_len).float()))
att_prev_list = [att_prev_list / att_prev_list.new(enc_hs_len).unsqueeze(-1)]
# att_prev_list: L' * [B x T] => cov_vec B x T
cov_vec = sum(att_prev_list)
# cov_vec: B x T => B x T x 1 => B x T x att_dim
cov_vec = self.wvec(cov_vec.unsqueeze(-1))
# dec_z_tiled: utt x frame x att_dim
dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)
# dot with gvec
# utt x frame x att_dim -> utt x frame
e = self.gvec(torch.tanh(cov_vec + self.pre_compute_enc_h + dec_z_tiled)).squeeze(2)
# NOTE consider zero padding when compute w.
if self.mask is None:
self.mask = to_device(self, make_pad_mask(enc_hs_len))
e.masked_fill_(self.mask, -float('inf'))
w = F.softmax(scaling * e, dim=1)
att_prev_list += [w]
# weighted sum over flames
# utt x hdim
# NOTE use bmm instead of sum(*)
c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)
return c, att_prev_list
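Unlike AttAdd, the state threaded between decoder steps here is a growing list of past attention weights, and the coverage vector is simply their sum. A minimal step-loop sketch with hypothetical sizes:

```python
import torch

att = AttCov(eprojs=320, dunits=300, att_dim=320)
enc_hs_pad = torch.randn(2, 50, 320)          # (B, T_max, eprojs) encoder states
enc_hs_len = [50, 35]

att.reset()
dec_z, att_prev_list = None, None
for _ in range(3):
    c, att_prev_list = att(enc_hs_pad, enc_hs_len, dec_z, att_prev_list)
    dec_z = torch.randn(2, 300)               # stand-in decoder state
# att_prev_list now holds the initial uniform weights plus one (B, T_max) tensor per step
```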
reset(self)
¶
AttCovLoc (Module)
¶Coverage mechanism location aware attention
This attention is a combination of coverage and location-aware attentions.
:param int eprojs: # projection-units of encoder
:param int dunits: # units of decoder
:param int att_dim: attention dimension
:param int aconv_chans: # channels of attention convolution
:param int aconv_filts: filter size of attention convolution
:param bool han_mode: flag to switch on hierarchical attention mode and not store pre_compute_enc_h
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/attentions.py
class AttCovLoc(torch.nn.Module):
"""Coverage mechanism location aware attention
This attention is a combination of coverage and location-aware attentions.
:param int eprojs: # projection-units of encoder
:param int dunits: # units of decoder
:param int att_dim: attention dimension
:param int aconv_chans: # channels of attention convolution
:param int aconv_filts: filter size of attention convolution
:param bool han_mode: flag to swith on mode of hierarchical attention and not store pre_compute_enc_h
"""
def __init__(self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False):
super(AttCovLoc, self).__init__()
self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False)
self.loc_conv = torch.nn.Conv2d(
1, aconv_chans, (1, 2 * aconv_filts + 1), padding=(0, aconv_filts), bias=False)
self.gvec = torch.nn.Linear(att_dim, 1)
self.dunits = dunits
self.eprojs = eprojs
self.att_dim = att_dim
self.h_length = None
self.enc_h = None
self.pre_compute_enc_h = None
self.aconv_chans = aconv_chans
self.mask = None
self.han_mode = han_mode
def reset(self):
"""reset states"""
self.h_length = None
self.enc_h = None
self.pre_compute_enc_h = None
self.mask = None
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev_list, scaling=2.0):
"""AttCovLoc forward
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
:param list enc_hs_len: padded encoder hidden state length (B)
:param torch.Tensor dec_z: decoder hidden state (B x D_dec)
:param list att_prev_list: list of previous attention weight
:param float scaling: scaling parameter before applying softmax
:return: attention weighted encoder state (B, D_enc)
:rtype: torch.Tensor
:return: list of previous attention weights
:rtype: list
"""
batch = len(enc_hs_pad)
# pre-compute all h outside the decoder loop
if self.pre_compute_enc_h is None or self.han_mode:
self.enc_h = enc_hs_pad # utt x frame x hdim
self.h_length = self.enc_h.size(1)
# utt x frame x att_dim
self.pre_compute_enc_h = self.mlp_enc(self.enc_h)
if dec_z is None:
dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
else:
dec_z = dec_z.view(batch, self.dunits)
# initialize attention weight with uniform dist.
if att_prev_list is None:
# if no bias, 0 0-pad goes 0
mask = 1. - make_pad_mask(enc_hs_len).float()
att_prev_list = [to_device(self, mask / mask.new(enc_hs_len).unsqueeze(-1))]
# att_prev_list: L' * [B x T] => cov_vec B x T
cov_vec = sum(att_prev_list)
# cov_vec: B x T -> B x 1 x 1 x T -> B x C x 1 x T
att_conv = self.loc_conv(cov_vec.view(batch, 1, 1, self.h_length))
# att_conv: utt x att_conv_chans x 1 x frame -> utt x frame x att_conv_chans
att_conv = att_conv.squeeze(2).transpose(1, 2)
# att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim
att_conv = self.mlp_att(att_conv)
# dec_z_tiled: utt x frame x att_dim
dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)
# dot with gvec
# utt x frame x att_dim -> utt x frame
e = self.gvec(torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)).squeeze(2)
# NOTE consider zero padding when compute w.
if self.mask is None:
self.mask = to_device(self, make_pad_mask(enc_hs_len))
e.masked_fill_(self.mask, -float('inf'))
w = F.softmax(scaling * e, dim=1)
att_prev_list += [w]
# weighted sum over flames
# utt x hdim
# NOTE use bmm instead of sum(*)
c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)
return c, att_prev_list
__init__(self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False)
special
¶Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/attentions.py
def __init__(self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False):
super(AttCovLoc, self).__init__()
self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False)
self.loc_conv = torch.nn.Conv2d(
1, aconv_chans, (1, 2 * aconv_filts + 1), padding=(0, aconv_filts), bias=False)
self.gvec = torch.nn.Linear(att_dim, 1)
self.dunits = dunits
self.eprojs = eprojs
self.att_dim = att_dim
self.h_length = None
self.enc_h = None
self.pre_compute_enc_h = None
self.aconv_chans = aconv_chans
self.mask = None
self.han_mode = han_mode
forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev_list, scaling=2.0)
¶AttCovLoc forward
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
:param list enc_hs_len: padded encoder hidden state length (B)
:param torch.Tensor dec_z: decoder hidden state (B x D_dec)
:param list att_prev_list: list of previous attention weights
:param float scaling: scaling parameter before applying softmax
:return: attention weighted encoder state (B, D_enc)
:rtype: torch.Tensor
:return: list of previous attention weights
:rtype: list
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/attentions.py
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev_list, scaling=2.0):
"""AttCovLoc forward
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
:param list enc_hs_len: padded encoder hidden state length (B)
:param torch.Tensor dec_z: decoder hidden state (B x D_dec)
:param list att_prev_list: list of previous attention weight
:param float scaling: scaling parameter before applying softmax
:return: attention weighted encoder state (B, D_enc)
:rtype: torch.Tensor
:return: list of previous attention weights
:rtype: list
"""
batch = len(enc_hs_pad)
# pre-compute all h outside the decoder loop
if self.pre_compute_enc_h is None or self.han_mode:
self.enc_h = enc_hs_pad # utt x frame x hdim
self.h_length = self.enc_h.size(1)
# utt x frame x att_dim
self.pre_compute_enc_h = self.mlp_enc(self.enc_h)
if dec_z is None:
dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
else:
dec_z = dec_z.view(batch, self.dunits)
# initialize attention weight with uniform dist.
if att_prev_list is None:
# if no bias, 0 0-pad goes 0
mask = 1. - make_pad_mask(enc_hs_len).float()
att_prev_list = [to_device(self, mask / mask.new(enc_hs_len).unsqueeze(-1))]
# att_prev_list: L' * [B x T] => cov_vec B x T
cov_vec = sum(att_prev_list)
# cov_vec: B x T -> B x 1 x 1 x T -> B x C x 1 x T
att_conv = self.loc_conv(cov_vec.view(batch, 1, 1, self.h_length))
# att_conv: utt x att_conv_chans x 1 x frame -> utt x frame x att_conv_chans
att_conv = att_conv.squeeze(2).transpose(1, 2)
# att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim
att_conv = self.mlp_att(att_conv)
# dec_z_tiled: utt x frame x att_dim
dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)
# dot with gvec
# utt x frame x att_dim -> utt x frame
e = self.gvec(torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)).squeeze(2)
# NOTE consider zero padding when compute w.
if self.mask is None:
self.mask = to_device(self, make_pad_mask(enc_hs_len))
e.masked_fill_(self.mask, -float('inf'))
w = F.softmax(scaling * e, dim=1)
att_prev_list += [w]
# weighted sum over flames
# utt x hdim
# NOTE use bmm instead of sum(*)
c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)
return c, att_prev_list
reset(self)
¶
AttDot (Module)
¶Dot product attention
:param int eprojs: # projection-units of encoder
:param int dunits: # units of decoder
:param int att_dim: attention dimension
:param bool han_mode: flag to switch on hierarchical attention mode and not store pre_compute_enc_h
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/attentions.py
class AttDot(torch.nn.Module):
"""Dot product attention
:param int eprojs: # projection-units of encoder
:param int dunits: # units of decoder
:param int att_dim: attention dimension
:param bool han_mode: flag to swith on mode of hierarchical attention and not store pre_compute_enc_h
"""
def __init__(self, eprojs, dunits, att_dim, han_mode=False):
super(AttDot, self).__init__()
self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
self.mlp_dec = torch.nn.Linear(dunits, att_dim)
self.dunits = dunits
self.eprojs = eprojs
self.att_dim = att_dim
self.h_length = None
self.enc_h = None
self.pre_compute_enc_h = None
self.mask = None
self.han_mode = han_mode
def reset(self):
"""reset states"""
self.h_length = None
self.enc_h = None
self.pre_compute_enc_h = None
self.mask = None
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0):
"""AttDot forward
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
:param list enc_hs_len: padded encoder hidden state length (B)
:param torch.Tensor dec_z: dummy (does not use)
:param torch.Tensor att_prev: dummy (does not use)
:param float scaling: scaling parameter before applying softmax
:return: attention weighted encoder state (B, D_enc)
:rtype: torch.Tensor
:return: previous attention weight (B x T_max)
:rtype: torch.Tensor
"""
batch = enc_hs_pad.size(0)
# pre-compute all h outside the decoder loop
if self.pre_compute_enc_h is None or self.han_mode:
self.enc_h = enc_hs_pad # utt x frame x hdim
self.h_length = self.enc_h.size(1)
# utt x frame x att_dim
self.pre_compute_enc_h = torch.tanh(self.mlp_enc(self.enc_h))
if dec_z is None:
dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
else:
dec_z = dec_z.view(batch, self.dunits)
e = torch.sum(self.pre_compute_enc_h * torch.tanh(self.mlp_dec(dec_z)).view(batch, 1, self.att_dim),
dim=2) # utt x frame
# NOTE consider zero padding when compute w.
if self.mask is None:
self.mask = to_device(self, make_pad_mask(enc_hs_len))
e.masked_fill_(self.mask, -float('inf'))
w = F.softmax(scaling * e, dim=1)
# weighted sum over flames
# utt x hdim
# NOTE use bmm instead of sum(*)
c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)
return c, w
__init__(self, eprojs, dunits, att_dim, han_mode=False)
special
¶Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/attentions.py
def __init__(self, eprojs, dunits, att_dim, han_mode=False):
super(AttDot, self).__init__()
self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
self.mlp_dec = torch.nn.Linear(dunits, att_dim)
self.dunits = dunits
self.eprojs = eprojs
self.att_dim = att_dim
self.h_length = None
self.enc_h = None
self.pre_compute_enc_h = None
self.mask = None
self.han_mode = han_mode
forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0)
¶AttDot forward
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
:param list enc_hs_len: padded encoder hidden state length (B)
:param torch.Tensor dec_z: dummy (not used)
:param torch.Tensor att_prev: dummy (not used)
:param float scaling: scaling parameter before applying softmax
:return: attention weighted encoder state (B, D_enc)
:rtype: torch.Tensor
:return: previous attention weight (B x T_max)
:rtype: torch.Tensor
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/attentions.py
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0):
"""AttDot forward
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
:param list enc_hs_len: padded encoder hidden state length (B)
:param torch.Tensor dec_z: dummy (does not use)
:param torch.Tensor att_prev: dummy (does not use)
:param float scaling: scaling parameter before applying softmax
:return: attention weighted encoder state (B, D_enc)
:rtype: torch.Tensor
:return: previous attention weight (B x T_max)
:rtype: torch.Tensor
"""
batch = enc_hs_pad.size(0)
# pre-compute all h outside the decoder loop
if self.pre_compute_enc_h is None or self.han_mode:
self.enc_h = enc_hs_pad # utt x frame x hdim
self.h_length = self.enc_h.size(1)
# utt x frame x att_dim
self.pre_compute_enc_h = torch.tanh(self.mlp_enc(self.enc_h))
if dec_z is None:
dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
else:
dec_z = dec_z.view(batch, self.dunits)
e = torch.sum(self.pre_compute_enc_h * torch.tanh(self.mlp_dec(dec_z)).view(batch, 1, self.att_dim),
dim=2) # utt x frame
# NOTE consider zero padding when compute w.
if self.mask is None:
self.mask = to_device(self, make_pad_mask(enc_hs_len))
e.masked_fill_(self.mask, -float('inf'))
w = F.softmax(scaling * e, dim=1)
# weighted sum over flames
# utt x hdim
# NOTE use bmm instead of sum(*)
c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)
return c, w
reset(self)
¶
AttForward (Module)
¶Forward attention module.
Reference: Forward attention in sequence-to-sequence acoustic modeling for speech synthesis (https://arxiv.org/pdf/1807.06736.pdf)
:param int eprojs: # projection-units of encoder
:param int dunits: # units of decoder
:param int att_dim: attention dimension
:param int aconv_chans: # channels of attention convolution
:param int aconv_filts: filter size of attention convolution
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/attentions.py
class AttForward(torch.nn.Module):
"""Forward attention module.
Reference: Forward attention in sequence-to-sequence acoustic modeling for speech synthesis
(https://arxiv.org/pdf/1807.06736.pdf)
:param int eprojs: # projection-units of encoder
:param int dunits: # units of decoder
:param int att_dim: attention dimension
:param int aconv_chans: # channels of attention convolution
:param int aconv_filts: filter size of attention convolution
"""
def __init__(self, eprojs, dunits, att_dim, aconv_chans, aconv_filts):
super(AttForward, self).__init__()
self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False)
self.loc_conv = torch.nn.Conv2d(
1, aconv_chans, (1, 2 * aconv_filts + 1), padding=(0, aconv_filts), bias=False)
self.gvec = torch.nn.Linear(att_dim, 1)
self.dunits = dunits
self.eprojs = eprojs
self.att_dim = att_dim
self.h_length = None
self.enc_h = None
self.pre_compute_enc_h = None
self.mask = None
def reset(self):
"""reset states"""
self.h_length = None
self.enc_h = None
self.pre_compute_enc_h = None
self.mask = None
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev,
scaling=1.0, last_attended_idx=None, backward_window=1, forward_window=3):
"""Calculate AttForward forward propagation.
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
:param list enc_hs_len: padded encoder hidden state length (B)
:param torch.Tensor dec_z: decoder hidden state (B x D_dec)
:param torch.Tensor att_prev: attention weights of previous step
:param float scaling: scaling parameter before applying softmax
:param int last_attended_idx: index of the inputs of the last attended
:param int backward_window: backward window size in attention constraint
:param int forward_window: forward window size in attetion constraint
:return: attention weighted encoder state (B, D_enc)
:rtype: torch.Tensor
:return: previous attention weights (B x T_max)
:rtype: torch.Tensor
"""
batch = len(enc_hs_pad)
# pre-compute all h outside the decoder loop
if self.pre_compute_enc_h is None:
self.enc_h = enc_hs_pad # utt x frame x hdim
self.h_length = self.enc_h.size(1)
# utt x frame x att_dim
self.pre_compute_enc_h = self.mlp_enc(self.enc_h)
if dec_z is None:
dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
else:
dec_z = dec_z.view(batch, self.dunits)
if att_prev is None:
# initial attention will be [1, 0, 0, ...]
att_prev = enc_hs_pad.new_zeros(*enc_hs_pad.size()[:2])
att_prev[:, 0] = 1.0
# att_prev: utt x frame -> utt x 1 x 1 x frame -> utt x att_conv_chans x 1 x frame
att_conv = self.loc_conv(att_prev.view(batch, 1, 1, self.h_length))
# att_conv: utt x att_conv_chans x 1 x frame -> utt x frame x att_conv_chans
att_conv = att_conv.squeeze(2).transpose(1, 2)
# att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim
att_conv = self.mlp_att(att_conv)
# dec_z_tiled: utt x frame x att_dim
dec_z_tiled = self.mlp_dec(dec_z).unsqueeze(1)
# dot with gvec
# utt x frame x att_dim -> utt x frame
e = self.gvec(torch.tanh(self.pre_compute_enc_h + dec_z_tiled + att_conv)).squeeze(2)
# NOTE: consider zero padding when compute w.
if self.mask is None:
self.mask = to_device(self, make_pad_mask(enc_hs_len))
e.masked_fill_(self.mask, -float('inf'))
# apply monotonic attention constraint (mainly for TTS)
if last_attended_idx is not None:
e = _apply_attention_constraint(e, last_attended_idx, backward_window, forward_window)
w = F.softmax(scaling * e, dim=1)
# forward attention
att_prev_shift = F.pad(att_prev, (1, 0))[:, :-1]
w = (att_prev + att_prev_shift) * w
# NOTE: clamp is needed to avoid nan gradient
w = F.normalize(torch.clamp(w, 1e-6), p=1, dim=1)
# weighted sum over flames
# utt x hdim
# NOTE use bmm instead of sum(*)
c = torch.sum(self.enc_h * w.unsqueeze(-1), dim=1)
return c, w
__init__(self, eprojs, dunits, att_dim, aconv_chans, aconv_filts)
special
¶Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/attentions.py
def __init__(self, eprojs, dunits, att_dim, aconv_chans, aconv_filts):
super(AttForward, self).__init__()
self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False)
self.loc_conv = torch.nn.Conv2d(
1, aconv_chans, (1, 2 * aconv_filts + 1), padding=(0, aconv_filts), bias=False)
self.gvec = torch.nn.Linear(att_dim, 1)
self.dunits = dunits
self.eprojs = eprojs
self.att_dim = att_dim
self.h_length = None
self.enc_h = None
self.pre_compute_enc_h = None
self.mask = None
forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=1.0, last_attended_idx=None, backward_window=1, forward_window=3)
¶Calculate AttForward forward propagation.
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
:param list enc_hs_len: padded encoder hidden state length (B)
:param torch.Tensor dec_z: decoder hidden state (B x D_dec)
:param torch.Tensor att_prev: attention weights of previous step
:param float scaling: scaling parameter before applying softmax
:param int last_attended_idx: index of the inputs of the last attended
:param int backward_window: backward window size in attention constraint
:param int forward_window: forward window size in attention constraint
:return: attention weighted encoder state (B, D_enc)
:rtype: torch.Tensor
:return: previous attention weights (B x T_max)
:rtype: torch.Tensor
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/attentions.py
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev,
scaling=1.0, last_attended_idx=None, backward_window=1, forward_window=3):
"""Calculate AttForward forward propagation.
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
:param list enc_hs_len: padded encoder hidden state length (B)
:param torch.Tensor dec_z: decoder hidden state (B x D_dec)
:param torch.Tensor att_prev: attention weights of previous step
:param float scaling: scaling parameter before applying softmax
:param int last_attended_idx: index of the inputs of the last attended
:param int backward_window: backward window size in attention constraint
:param int forward_window: forward window size in attetion constraint
:return: attention weighted encoder state (B, D_enc)
:rtype: torch.Tensor
:return: previous attention weights (B x T_max)
:rtype: torch.Tensor
"""
batch = len(enc_hs_pad)
# pre-compute all h outside the decoder loop
if self.pre_compute_enc_h is None:
self.enc_h = enc_hs_pad # utt x frame x hdim
self.h_length = self.enc_h.size(1)
# utt x frame x att_dim
self.pre_compute_enc_h = self.mlp_enc(self.enc_h)
if dec_z is None:
dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
else:
dec_z = dec_z.view(batch, self.dunits)
if att_prev is None:
# initial attention will be [1, 0, 0, ...]
att_prev = enc_hs_pad.new_zeros(*enc_hs_pad.size()[:2])
att_prev[:, 0] = 1.0
# att_prev: utt x frame -> utt x 1 x 1 x frame -> utt x att_conv_chans x 1 x frame
att_conv = self.loc_conv(att_prev.view(batch, 1, 1, self.h_length))
# att_conv: utt x att_conv_chans x 1 x frame -> utt x frame x att_conv_chans
att_conv = att_conv.squeeze(2).transpose(1, 2)
# att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim
att_conv = self.mlp_att(att_conv)
# dec_z_tiled: utt x frame x att_dim
dec_z_tiled = self.mlp_dec(dec_z).unsqueeze(1)
# dot with gvec
# utt x frame x att_dim -> utt x frame
e = self.gvec(torch.tanh(self.pre_compute_enc_h + dec_z_tiled + att_conv)).squeeze(2)
# NOTE: consider zero padding when compute w.
if self.mask is None:
self.mask = to_device(self, make_pad_mask(enc_hs_len))
e.masked_fill_(self.mask, -float('inf'))
# apply monotonic attention constraint (mainly for TTS)
if last_attended_idx is not None:
e = _apply_attention_constraint(e, last_attended_idx, backward_window, forward_window)
w = F.softmax(scaling * e, dim=1)
# forward attention
att_prev_shift = F.pad(att_prev, (1, 0))[:, :-1]
w = (att_prev + att_prev_shift) * w
# NOTE: clamp is needed to avoid nan gradient
w = F.normalize(torch.clamp(w, 1e-6), p=1, dim=1)
# weighted sum over flames
# utt x hdim
# NOTE use bmm instead of sum(*)
c = torch.sum(self.enc_h * w.unsqueeze(-1), dim=1)
return c, w
reset(self)
¶
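A minimal usage sketch for AttForward (the import path and all sizes below are illustrative assumptions, not prescribed by the source): the first decoder step passes att_prev=None, later steps feed the returned weights back in, and the monotonic constraint is enabled by passing last_attended_idx.
import torch
from tools.espnet_minimal.nets.pytorch_backend.rnn.attentions import AttForward  # import path assumed
att = AttForward(eprojs=512, dunits=1024, att_dim=128, aconv_chans=32, aconv_filts=15)
att.reset()
hs = torch.randn(2, 60, 512)        # B x T_max x eprojs
hlens = [60, 48]                    # true encoder lengths
dec_z = torch.randn(2, 1024)        # B x dunits
c, w = att(hs, hlens, dec_z, None)  # step 1: att_prev=None -> weights start at [1, 0, 0, ...]
c, w = att(hs, hlens, dec_z, w,     # later steps reuse the returned weights
           last_attended_idx=0, backward_window=1, forward_window=3)  # optional monotonic attention constraint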
AttForwardTA (Module)
¶Forward attention with transition agent module.
Forward attention in sequence-to-sequence acoustic modeling for speech synthesis
(https://arxiv.org/pdf/1807.06736.pdf)
:param int eunits: # units of encoder
:param int dunits: # units of decoder
:param int att_dim: attention dimension
:param int aconv_chans: # channels of attention convolution
:param int aconv_filts: filter size of attention convolution
:param int odim: output dimension
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/attentions.py
class AttForwardTA(torch.nn.Module):
"""Forward attention with transition agent module.
Reference: Forward attention in sequence-to-sequence acoustic modeling for speech synthesis
(https://arxiv.org/pdf/1807.06736.pdf)
:param int eunits: # units of encoder
:param int dunits: # units of decoder
:param int att_dim: attention dimension
:param int aconv_chans: # channels of attention convolution
:param int aconv_filts: filter size of attention convolution
:param int odim: output dimension
"""
def __init__(self, eunits, dunits, att_dim, aconv_chans, aconv_filts, odim):
super(AttForwardTA, self).__init__()
self.mlp_enc = torch.nn.Linear(eunits, att_dim)
self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
self.mlp_ta = torch.nn.Linear(eunits + dunits + odim, 1)
self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False)
self.loc_conv = torch.nn.Conv2d(
1, aconv_chans, (1, 2 * aconv_filts + 1), padding=(0, aconv_filts), bias=False)
self.gvec = torch.nn.Linear(att_dim, 1)
self.dunits = dunits
self.eunits = eunits
self.att_dim = att_dim
self.h_length = None
self.enc_h = None
self.pre_compute_enc_h = None
self.mask = None
self.trans_agent_prob = 0.5
def reset(self):
self.h_length = None
self.enc_h = None
self.pre_compute_enc_h = None
self.mask = None
self.trans_agent_prob = 0.5
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, out_prev,
scaling=1.0, last_attended_idx=None, backward_window=1, forward_window=3):
"""Calculate AttForwardTA forward propagation.
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B, Tmax, eunits)
:param list enc_hs_len: padded encoder hidden state length (B)
:param torch.Tensor dec_z: decoder hidden state (B, dunits)
:param torch.Tensor att_prev: attention weights of previous step
:param torch.Tensor out_prev: decoder outputs of previous step (B, odim)
:param float scaling: scaling parameter before applying softmax
:param int last_attended_idx: index of the inputs of the last attended
:param int backward_window: backward window size in attention constraint
:param int forward_window: forward window size in attetion constraint
:return: attention weighted encoder state (B, dunits)
:rtype: torch.Tensor
:return: previous attention weights (B, Tmax)
:rtype: torch.Tensor
"""
batch = len(enc_hs_pad)
# pre-compute all h outside the decoder loop
if self.pre_compute_enc_h is None:
self.enc_h = enc_hs_pad # utt x frame x hdim
self.h_length = self.enc_h.size(1)
# utt x frame x att_dim
self.pre_compute_enc_h = self.mlp_enc(self.enc_h)
if dec_z is None:
dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
else:
dec_z = dec_z.view(batch, self.dunits)
if att_prev is None:
# initial attention will be [1, 0, 0, ...]
att_prev = enc_hs_pad.new_zeros(*enc_hs_pad.size()[:2])
att_prev[:, 0] = 1.0
# att_prev: utt x frame -> utt x 1 x 1 x frame -> utt x att_conv_chans x 1 x frame
att_conv = self.loc_conv(att_prev.view(batch, 1, 1, self.h_length))
# att_conv: utt x att_conv_chans x 1 x frame -> utt x frame x att_conv_chans
att_conv = att_conv.squeeze(2).transpose(1, 2)
# att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim
att_conv = self.mlp_att(att_conv)
# dec_z_tiled: utt x frame x att_dim
dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)
# dot with gvec
# utt x frame x att_dim -> utt x frame
e = self.gvec(torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)).squeeze(2)
# NOTE consider zero padding when compute w.
if self.mask is None:
self.mask = to_device(self, make_pad_mask(enc_hs_len))
e.masked_fill_(self.mask, -float('inf'))
# apply monotonic attention constraint (mainly for TTS)
if last_attended_idx is not None:
e = _apply_attention_constraint(e, last_attended_idx, backward_window, forward_window)
w = F.softmax(scaling * e, dim=1)
# forward attention
att_prev_shift = F.pad(att_prev, (1, 0))[:, :-1]
w = (self.trans_agent_prob * att_prev + (1 - self.trans_agent_prob) * att_prev_shift) * w
# NOTE: clamp is needed to avoid nan gradient
w = F.normalize(torch.clamp(w, 1e-6), p=1, dim=1)
# weighted sum over flames
# utt x hdim
# NOTE use bmm instead of sum(*)
c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)
# update transition agent prob
self.trans_agent_prob = torch.sigmoid(
self.mlp_ta(torch.cat([c, out_prev, dec_z], dim=1)))
return c, w
__init__(self, eunits, dunits, att_dim, aconv_chans, aconv_filts, odim)
special
¶Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/attentions.py
def __init__(self, eunits, dunits, att_dim, aconv_chans, aconv_filts, odim):
super(AttForwardTA, self).__init__()
self.mlp_enc = torch.nn.Linear(eunits, att_dim)
self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
self.mlp_ta = torch.nn.Linear(eunits + dunits + odim, 1)
self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False)
self.loc_conv = torch.nn.Conv2d(
1, aconv_chans, (1, 2 * aconv_filts + 1), padding=(0, aconv_filts), bias=False)
self.gvec = torch.nn.Linear(att_dim, 1)
self.dunits = dunits
self.eunits = eunits
self.att_dim = att_dim
self.h_length = None
self.enc_h = None
self.pre_compute_enc_h = None
self.mask = None
self.trans_agent_prob = 0.5
forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, out_prev, scaling=1.0, last_attended_idx=None, backward_window=1, forward_window=3)
¶Calculate AttForwardTA forward propagation.
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B, Tmax, eunits)
:param list enc_hs_len: padded encoder hidden state length (B)
:param torch.Tensor dec_z: decoder hidden state (B, dunits)
:param torch.Tensor att_prev: attention weights of previous step
:param torch.Tensor out_prev: decoder outputs of previous step (B, odim)
:param float scaling: scaling parameter before applying softmax
:param int last_attended_idx: index of the inputs of the last attended
:param int backward_window: backward window size in attention constraint
:param int forward_window: forward window size in attention constraint
:return: attention weighted encoder state (B, dunits)
:rtype: torch.Tensor
:return: previous attention weights (B, Tmax)
:rtype: torch.Tensor
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/attentions.py
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, out_prev,
scaling=1.0, last_attended_idx=None, backward_window=1, forward_window=3):
"""Calculate AttForwardTA forward propagation.
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B, Tmax, eunits)
:param list enc_hs_len: padded encoder hidden state length (B)
:param torch.Tensor dec_z: decoder hidden state (B, dunits)
:param torch.Tensor att_prev: attention weights of previous step
:param torch.Tensor out_prev: decoder outputs of previous step (B, odim)
:param float scaling: scaling parameter before applying softmax
:param int last_attended_idx: index of the inputs of the last attended
:param int backward_window: backward window size in attention constraint
:param int forward_window: forward window size in attetion constraint
:return: attention weighted encoder state (B, dunits)
:rtype: torch.Tensor
:return: previous attention weights (B, Tmax)
:rtype: torch.Tensor
"""
batch = len(enc_hs_pad)
# pre-compute all h outside the decoder loop
if self.pre_compute_enc_h is None:
self.enc_h = enc_hs_pad # utt x frame x hdim
self.h_length = self.enc_h.size(1)
# utt x frame x att_dim
self.pre_compute_enc_h = self.mlp_enc(self.enc_h)
if dec_z is None:
dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
else:
dec_z = dec_z.view(batch, self.dunits)
if att_prev is None:
# initial attention will be [1, 0, 0, ...]
att_prev = enc_hs_pad.new_zeros(*enc_hs_pad.size()[:2])
att_prev[:, 0] = 1.0
# att_prev: utt x frame -> utt x 1 x 1 x frame -> utt x att_conv_chans x 1 x frame
att_conv = self.loc_conv(att_prev.view(batch, 1, 1, self.h_length))
# att_conv: utt x att_conv_chans x 1 x frame -> utt x frame x att_conv_chans
att_conv = att_conv.squeeze(2).transpose(1, 2)
# att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim
att_conv = self.mlp_att(att_conv)
# dec_z_tiled: utt x frame x att_dim
dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)
# dot with gvec
# utt x frame x att_dim -> utt x frame
e = self.gvec(torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)).squeeze(2)
# NOTE consider zero padding when compute w.
if self.mask is None:
self.mask = to_device(self, make_pad_mask(enc_hs_len))
e.masked_fill_(self.mask, -float('inf'))
# apply monotonic attention constraint (mainly for TTS)
if last_attended_idx is not None:
e = _apply_attention_constraint(e, last_attended_idx, backward_window, forward_window)
w = F.softmax(scaling * e, dim=1)
# forward attention
att_prev_shift = F.pad(att_prev, (1, 0))[:, :-1]
w = (self.trans_agent_prob * att_prev + (1 - self.trans_agent_prob) * att_prev_shift) * w
# NOTE: clamp is needed to avoid nan gradient
w = F.normalize(torch.clamp(w, 1e-6), p=1, dim=1)
# weighted sum over flames
# utt x hdim
# NOTE use bmm instead of sum(*)
c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)
# update transition agent prob
self.trans_agent_prob = torch.sigmoid(
self.mlp_ta(torch.cat([c, out_prev, dec_z], dim=1)))
return c, w
reset(self)
¶
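Usage sketch for AttForwardTA under the same illustrative assumptions (import path and sizes are not from the source): compared with AttForward, each call additionally takes the previous decoder output frame out_prev and updates the internal transition-agent probability.
import torch
from tools.espnet_minimal.nets.pytorch_backend.rnn.attentions import AttForwardTA  # import path assumed
att = AttForwardTA(eunits=512, dunits=1024, att_dim=128, aconv_chans=32, aconv_filts=15, odim=80)
att.reset()                                    # also resets trans_agent_prob to 0.5
hs = torch.randn(2, 60, 512)                   # B x Tmax x eunits
hlens = [60, 55]
dec_z = torch.randn(2, 1024)                   # B x dunits
out_prev = torch.zeros(2, 80)                  # previous decoder output frame, B x odim
c, w = att(hs, hlens, dec_z, None, out_prev)   # first step; call also updates self.trans_agent_prob
c, w = att(hs, hlens, dec_z, w, out_prev)      # later steps feed the returned weights back in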
AttLoc (Module)
¶location-aware attention module.
Attention-Based Models for Speech Recognition
(https://arxiv.org/pdf/1506.07503.pdf)
:param int eprojs: # projection-units of encoder
:param int dunits: # units of decoder
:param int att_dim: attention dimension
:param int aconv_chans: # channels of attention convolution
:param int aconv_filts: filter size of attention convolution
:param bool han_mode: flag to switch on hierarchical attention mode and not store pre_compute_enc_h
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/attentions.py
class AttLoc(torch.nn.Module):
"""location-aware attention module.
Reference: Attention-Based Models for Speech Recognition
(https://arxiv.org/pdf/1506.07503.pdf)
:param int eprojs: # projection-units of encoder
:param int dunits: # units of decoder
:param int att_dim: attention dimension
:param int aconv_chans: # channels of attention convolution
:param int aconv_filts: filter size of attention convolution
:param bool han_mode: flag to swith on mode of hierarchical attention and not store pre_compute_enc_h
"""
def __init__(self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False):
super(AttLoc, self).__init__()
self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False)
self.loc_conv = torch.nn.Conv2d(
1, aconv_chans, (1, 2 * aconv_filts + 1), padding=(0, aconv_filts), bias=False)
self.gvec = torch.nn.Linear(att_dim, 1)
self.dunits = dunits
self.eprojs = eprojs
self.att_dim = att_dim
self.h_length = None
self.enc_h = None
self.pre_compute_enc_h = None
self.mask = None
self.han_mode = han_mode
def reset(self):
"""reset states"""
self.h_length = None
self.enc_h = None
self.pre_compute_enc_h = None
self.mask = None
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev,
scaling=2.0, last_attended_idx=None, backward_window=1, forward_window=3):
"""Calcualte AttLoc forward propagation.
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
:param list enc_hs_len: padded encoder hidden state length (B)
:param torch.Tensor dec_z: decoder hidden state (B x D_dec)
:param torch.Tensor att_prev: previous attention weight (B x T_max)
:param float scaling: scaling parameter before applying softmax
:param torch.Tensor forward_window: forward window size when constraining attention
:param int last_attended_idx: index of the inputs of the last attended
:param int backward_window: backward window size in attention constraint
:param int forward_window: forward window size in attetion constraint
:return: attention weighted encoder state (B, D_enc)
:rtype: torch.Tensor
:return: previous attention weights (B x T_max)
:rtype: torch.Tensor
"""
batch = len(enc_hs_pad)
# pre-compute all h outside the decoder loop
if self.pre_compute_enc_h is None or self.han_mode:
self.enc_h = enc_hs_pad # utt x frame x hdim
self.h_length = self.enc_h.size(1)
# utt x frame x att_dim
self.pre_compute_enc_h = self.mlp_enc(self.enc_h)
if dec_z is None:
dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
else:
dec_z = dec_z.view(batch, self.dunits)
# initialize attention weight with uniform dist.
if att_prev is None:
# if no bias, 0 0-pad goes 0
att_prev = (1. - make_pad_mask(enc_hs_len).to(device=dec_z.device, dtype=dec_z.dtype))
att_prev = att_prev / att_prev.new(enc_hs_len).unsqueeze(-1)
# att_prev: utt x frame -> utt x 1 x 1 x frame -> utt x att_conv_chans x 1 x frame
att_conv = self.loc_conv(att_prev.view(batch, 1, 1, self.h_length))
# att_conv: utt x att_conv_chans x 1 x frame -> utt x frame x att_conv_chans
att_conv = att_conv.squeeze(2).transpose(1, 2)
# att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim
att_conv = self.mlp_att(att_conv)
# dec_z_tiled: utt x frame x att_dim
dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)
# dot with gvec
# utt x frame x att_dim -> utt x frame
e = self.gvec(torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)).squeeze(2)
# NOTE: consider zero padding when compute w.
if self.mask is None:
self.mask = to_device(self, make_pad_mask(enc_hs_len))
e.masked_fill_(self.mask, -float('inf'))
# apply monotonic attention constraint (mainly for TTS)
if last_attended_idx is not None:
e = _apply_attention_constraint(e, last_attended_idx, backward_window, forward_window)
w = F.softmax(scaling * e, dim=1)
# weighted sum over flames
# utt x hdim
c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)
return c, w
__init__(self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False)
special
¶Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/attentions.py
def __init__(self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False):
super(AttLoc, self).__init__()
self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False)
self.loc_conv = torch.nn.Conv2d(
1, aconv_chans, (1, 2 * aconv_filts + 1), padding=(0, aconv_filts), bias=False)
self.gvec = torch.nn.Linear(att_dim, 1)
self.dunits = dunits
self.eprojs = eprojs
self.att_dim = att_dim
self.h_length = None
self.enc_h = None
self.pre_compute_enc_h = None
self.mask = None
self.han_mode = han_mode
forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0, last_attended_idx=None, backward_window=1, forward_window=3)
¶Calculate AttLoc forward propagation.
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
:param list enc_hs_len: padded encoder hidden state length (B)
:param torch.Tensor dec_z: decoder hidden state (B x D_dec)
:param torch.Tensor att_prev: previous attention weight (B x T_max)
:param float scaling: scaling parameter before applying softmax
:param int last_attended_idx: index of the inputs of the last attended
:param int backward_window: backward window size in attention constraint
:param int forward_window: forward window size in attention constraint
:return: attention weighted encoder state (B, D_enc)
:rtype: torch.Tensor
:return: previous attention weights (B x T_max)
:rtype: torch.Tensor
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/attentions.py
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev,
scaling=2.0, last_attended_idx=None, backward_window=1, forward_window=3):
"""Calcualte AttLoc forward propagation.
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
:param list enc_hs_len: padded encoder hidden state length (B)
:param torch.Tensor dec_z: decoder hidden state (B x D_dec)
:param torch.Tensor att_prev: previous attention weight (B x T_max)
:param float scaling: scaling parameter before applying softmax
:param torch.Tensor forward_window: forward window size when constraining attention
:param int last_attended_idx: index of the inputs of the last attended
:param int backward_window: backward window size in attention constraint
:param int forward_window: forward window size in attetion constraint
:return: attention weighted encoder state (B, D_enc)
:rtype: torch.Tensor
:return: previous attention weights (B x T_max)
:rtype: torch.Tensor
"""
batch = len(enc_hs_pad)
# pre-compute all h outside the decoder loop
if self.pre_compute_enc_h is None or self.han_mode:
self.enc_h = enc_hs_pad # utt x frame x hdim
self.h_length = self.enc_h.size(1)
# utt x frame x att_dim
self.pre_compute_enc_h = self.mlp_enc(self.enc_h)
if dec_z is None:
dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
else:
dec_z = dec_z.view(batch, self.dunits)
# initialize attention weight with uniform dist.
if att_prev is None:
# if no bias, 0 0-pad goes 0
att_prev = (1. - make_pad_mask(enc_hs_len).to(device=dec_z.device, dtype=dec_z.dtype))
att_prev = att_prev / att_prev.new(enc_hs_len).unsqueeze(-1)
# att_prev: utt x frame -> utt x 1 x 1 x frame -> utt x att_conv_chans x 1 x frame
att_conv = self.loc_conv(att_prev.view(batch, 1, 1, self.h_length))
# att_conv: utt x att_conv_chans x 1 x frame -> utt x frame x att_conv_chans
att_conv = att_conv.squeeze(2).transpose(1, 2)
# att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim
att_conv = self.mlp_att(att_conv)
# dec_z_tiled: utt x frame x att_dim
dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)
# dot with gvec
# utt x frame x att_dim -> utt x frame
e = self.gvec(torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)).squeeze(2)
# NOTE: consider zero padding when compute w.
if self.mask is None:
self.mask = to_device(self, make_pad_mask(enc_hs_len))
e.masked_fill_(self.mask, -float('inf'))
# apply monotonic attention constraint (mainly for TTS)
if last_attended_idx is not None:
e = _apply_attention_constraint(e, last_attended_idx, backward_window, forward_window)
w = F.softmax(scaling * e, dim=1)
# weighted sum over flames
# utt x hdim
c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)
return c, w
reset(self)
¶
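Usage sketch for AttLoc (sizes and import path are illustrative assumptions): with att_prev=None the weights start uniform over the unpadded frames; afterwards the previous weights drive the location convolution.
import torch
from tools.espnet_minimal.nets.pytorch_backend.rnn.attentions import AttLoc  # import path assumed
att = AttLoc(eprojs=320, dunits=300, att_dim=320, aconv_chans=10, aconv_filts=100)
att.reset()
hs = torch.randn(4, 100, 320)       # B x T_max x eprojs
hlens = [100, 90, 80, 70]
dec_z = torch.randn(4, 300)         # B x dunits
c, w = att(hs, hlens, dec_z, None)  # att_prev=None -> uniform weights over the unpadded frames
c, w = att(hs, hlens, dec_z, w)     # later steps: previous weights feed the location convolution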
AttLoc2D (Module)
¶2D location-aware attention
This attention is an extended version of location-aware attention. It takes not only the previous attention weights but also the attention weights of earlier decoding steps into account.
:param int eprojs: # projection-units of encoder
:param int dunits: # units of decoder
:param int att_dim: attention dimension
:param int aconv_chans: # channels of attention convolution
:param int aconv_filts: filter size of attention convolution
:param int att_win: attention window size (default=5)
:param bool han_mode: flag to switch on hierarchical attention mode and not store pre_compute_enc_h
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/attentions.py
class AttLoc2D(torch.nn.Module):
"""2D location-aware attention
This attention is an extended version of location aware attention.
It take not only one frame before attention weights, but also earlier frames into account.
:param int eprojs: # projection-units of encoder
:param int dunits: # units of decoder
:param int att_dim: attention dimension
:param int aconv_chans: # channels of attention convolution
:param int aconv_filts: filter size of attention convolution
:param int att_win: attention window size (default=5)
:param bool han_mode: flag to swith on mode of hierarchical attention and not store pre_compute_enc_h
"""
def __init__(self, eprojs, dunits, att_dim, att_win, aconv_chans, aconv_filts, han_mode=False):
super(AttLoc2D, self).__init__()
self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False)
self.loc_conv = torch.nn.Conv2d(
1, aconv_chans, (att_win, 2 * aconv_filts + 1), padding=(0, aconv_filts), bias=False)
self.gvec = torch.nn.Linear(att_dim, 1)
self.dunits = dunits
self.eprojs = eprojs
self.att_dim = att_dim
self.h_length = None
self.enc_h = None
self.pre_compute_enc_h = None
self.aconv_chans = aconv_chans
self.att_win = att_win
self.mask = None
self.han_mode = han_mode
def reset(self):
"""reset states"""
self.h_length = None
self.enc_h = None
self.pre_compute_enc_h = None
self.mask = None
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0):
"""AttLoc2D forward
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
:param list enc_hs_len: padded encoder hidden state length (B)
:param torch.Tensor dec_z: decoder hidden state (B x D_dec)
:param torch.Tensor att_prev: previous attention weight (B x att_win x T_max)
:param float scaling: scaling parameter before applying softmax
:return: attention weighted encoder state (B, D_enc)
:rtype: torch.Tensor
:return: previous attention weights (B x att_win x T_max)
:rtype: torch.Tensor
"""
batch = len(enc_hs_pad)
# pre-compute all h outside the decoder loop
if self.pre_compute_enc_h is None or self.han_mode:
self.enc_h = enc_hs_pad # utt x frame x hdim
self.h_length = self.enc_h.size(1)
# utt x frame x att_dim
self.pre_compute_enc_h = self.mlp_enc(self.enc_h)
if dec_z is None:
dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
else:
dec_z = dec_z.view(batch, self.dunits)
# initialize attention weight with uniform dist.
if att_prev is None:
# B * [Li x att_win]
# if no bias, 0 0-pad goes 0
att_prev = to_device(self, (1. - make_pad_mask(enc_hs_len).float()))
att_prev = att_prev / att_prev.new(enc_hs_len).unsqueeze(-1)
att_prev = att_prev.unsqueeze(1).expand(-1, self.att_win, -1)
# att_prev: B x att_win x Tmax -> B x 1 x att_win x Tmax -> B x C x 1 x Tmax
att_conv = self.loc_conv(att_prev.unsqueeze(1))
# att_conv: B x C x 1 x Tmax -> B x Tmax x C
att_conv = att_conv.squeeze(2).transpose(1, 2)
# att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim
att_conv = self.mlp_att(att_conv)
# dec_z_tiled: utt x frame x att_dim
dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)
# dot with gvec
# utt x frame x att_dim -> utt x frame
e = self.gvec(torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)).squeeze(2)
# NOTE consider zero padding when compute w.
if self.mask is None:
self.mask = to_device(self, make_pad_mask(enc_hs_len))
e.masked_fill_(self.mask, -float('inf'))
w = F.softmax(scaling * e, dim=1)
# weighted sum over flames
# utt x hdim
# NOTE use bmm instead of sum(*)
c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)
# update att_prev: B x att_win x Tmax -> B x att_win+1 x Tmax -> B x att_win x Tmax
att_prev = torch.cat([att_prev, w.unsqueeze(1)], dim=1)
att_prev = att_prev[:, 1:]
return c, att_prev
__init__(self, eprojs, dunits, att_dim, att_win, aconv_chans, aconv_filts, han_mode=False)
special
¶Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/attentions.py
def __init__(self, eprojs, dunits, att_dim, att_win, aconv_chans, aconv_filts, han_mode=False):
super(AttLoc2D, self).__init__()
self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False)
self.loc_conv = torch.nn.Conv2d(
1, aconv_chans, (att_win, 2 * aconv_filts + 1), padding=(0, aconv_filts), bias=False)
self.gvec = torch.nn.Linear(att_dim, 1)
self.dunits = dunits
self.eprojs = eprojs
self.att_dim = att_dim
self.h_length = None
self.enc_h = None
self.pre_compute_enc_h = None
self.aconv_chans = aconv_chans
self.att_win = att_win
self.mask = None
self.han_mode = han_mode
forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0)
¶AttLoc2D forward
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
:param list enc_hs_len: padded encoder hidden state length (B)
:param torch.Tensor dec_z: decoder hidden state (B x D_dec)
:param torch.Tensor att_prev: previous attention weight (B x att_win x T_max)
:param float scaling: scaling parameter before applying softmax
:return: attention weighted encoder state (B, D_enc)
:rtype: torch.Tensor
:return: previous attention weights (B x att_win x T_max)
:rtype: torch.Tensor
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/attentions.py
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0):
"""AttLoc2D forward
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
:param list enc_hs_len: padded encoder hidden state length (B)
:param torch.Tensor dec_z: decoder hidden state (B x D_dec)
:param torch.Tensor att_prev: previous attention weight (B x att_win x T_max)
:param float scaling: scaling parameter before applying softmax
:return: attention weighted encoder state (B, D_enc)
:rtype: torch.Tensor
:return: previous attention weights (B x att_win x T_max)
:rtype: torch.Tensor
"""
batch = len(enc_hs_pad)
# pre-compute all h outside the decoder loop
if self.pre_compute_enc_h is None or self.han_mode:
self.enc_h = enc_hs_pad # utt x frame x hdim
self.h_length = self.enc_h.size(1)
# utt x frame x att_dim
self.pre_compute_enc_h = self.mlp_enc(self.enc_h)
if dec_z is None:
dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
else:
dec_z = dec_z.view(batch, self.dunits)
# initialize attention weight with uniform dist.
if att_prev is None:
# B * [Li x att_win]
# if no bias, 0 0-pad goes 0
att_prev = to_device(self, (1. - make_pad_mask(enc_hs_len).float()))
att_prev = att_prev / att_prev.new(enc_hs_len).unsqueeze(-1)
att_prev = att_prev.unsqueeze(1).expand(-1, self.att_win, -1)
# att_prev: B x att_win x Tmax -> B x 1 x att_win x Tmax -> B x C x 1 x Tmax
att_conv = self.loc_conv(att_prev.unsqueeze(1))
# att_conv: B x C x 1 x Tmax -> B x Tmax x C
att_conv = att_conv.squeeze(2).transpose(1, 2)
# att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim
att_conv = self.mlp_att(att_conv)
# dec_z_tiled: utt x frame x att_dim
dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)
# dot with gvec
# utt x frame x att_dim -> utt x frame
e = self.gvec(torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)).squeeze(2)
# NOTE consider zero padding when compute w.
if self.mask is None:
self.mask = to_device(self, make_pad_mask(enc_hs_len))
e.masked_fill_(self.mask, -float('inf'))
w = F.softmax(scaling * e, dim=1)
# weighted sum over flames
# utt x hdim
# NOTE use bmm instead of sum(*)
c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)
# update att_prev: B x att_win x Tmax -> B x att_win+1 x Tmax -> B x att_win x Tmax
att_prev = torch.cat([att_prev, w.unsqueeze(1)], dim=1)
att_prev = att_prev[:, 1:]
return c, att_prev
reset(self)
¶
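Usage sketch for AttLoc2D (illustrative assumptions as above): the second return value is a rolling stack of the last att_win weight vectors, and that whole stack, not a single weight vector, is what gets passed back on the next step.
import torch
from tools.espnet_minimal.nets.pytorch_backend.rnn.attentions import AttLoc2D  # import path assumed
att = AttLoc2D(eprojs=320, dunits=300, att_dim=320, att_win=5, aconv_chans=10, aconv_filts=100)
att.reset()
hs = torch.randn(2, 80, 320)                   # B x T_max x eprojs
hlens = [80, 64]
dec_z = torch.randn(2, 300)
c, att_prev = att(hs, hlens, dec_z, None)      # att_prev: B x att_win x T_max
c, att_prev = att(hs, hlens, dec_z, att_prev)  # feed the returned weight stack back in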
AttLocRec (Module)
¶location-aware recurrent attention
This attention is an extended version of location-aware attention. Using an RNN, it takes the history of attention weights into account.
:param int eprojs: # projection-units of encoder
:param int dunits: # units of decoder
:param int att_dim: attention dimension
:param int aconv_chans: # channels of attention convolution
:param int aconv_filts: filter size of attention convolution
:param bool han_mode: flag to switch on hierarchical attention mode and not store pre_compute_enc_h
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/attentions.py
class AttLocRec(torch.nn.Module):
"""location-aware recurrent attention
This attention is an extended version of location aware attention.
With the use of RNN, it take the effect of the history of attention weights into account.
:param int eprojs: # projection-units of encoder
:param int dunits: # units of decoder
:param int att_dim: attention dimension
:param int aconv_chans: # channels of attention convolution
:param int aconv_filts: filter size of attention convolution
:param bool han_mode: flag to swith on mode of hierarchical attention and not store pre_compute_enc_h
"""
def __init__(self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False):
super(AttLocRec, self).__init__()
self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
self.loc_conv = torch.nn.Conv2d(
1, aconv_chans, (1, 2 * aconv_filts + 1), padding=(0, aconv_filts), bias=False)
self.att_lstm = torch.nn.LSTMCell(aconv_chans, att_dim, bias=False)
self.gvec = torch.nn.Linear(att_dim, 1)
self.dunits = dunits
self.eprojs = eprojs
self.att_dim = att_dim
self.h_length = None
self.enc_h = None
self.pre_compute_enc_h = None
self.mask = None
self.han_mode = han_mode
def reset(self):
"""reset states"""
self.h_length = None
self.enc_h = None
self.pre_compute_enc_h = None
self.mask = None
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev_states, scaling=2.0):
"""AttLocRec forward
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
:param list enc_hs_len: padded encoder hidden state length (B)
:param torch.Tensor dec_z: decoder hidden state (B x D_dec)
:param tuple att_prev_states: previous attention weight and lstm states
((B, T_max), ((B, att_dim), (B, att_dim)))
:param float scaling: scaling parameter before applying softmax
:return: attention weighted encoder state (B, D_enc)
:rtype: torch.Tensor
:return: previous attention weights and lstm states (w, (hx, cx))
((B, T_max), ((B, att_dim), (B, att_dim)))
:rtype: tuple
"""
batch = len(enc_hs_pad)
# pre-compute all h outside the decoder loop
if self.pre_compute_enc_h is None or self.han_mode:
self.enc_h = enc_hs_pad # utt x frame x hdim
self.h_length = self.enc_h.size(1)
# utt x frame x att_dim
self.pre_compute_enc_h = self.mlp_enc(self.enc_h)
if dec_z is None:
dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
else:
dec_z = dec_z.view(batch, self.dunits)
if att_prev_states is None:
# initialize attention weight with uniform dist.
# if no bias, 0 0-pad goes 0
att_prev = to_device(self, (1. - make_pad_mask(enc_hs_len).float()))
att_prev = att_prev / att_prev.new(enc_hs_len).unsqueeze(-1)
# initialize lstm states
att_h = enc_hs_pad.new_zeros(batch, self.att_dim)
att_c = enc_hs_pad.new_zeros(batch, self.att_dim)
att_states = (att_h, att_c)
else:
att_prev = att_prev_states[0]
att_states = att_prev_states[1]
# B x 1 x 1 x T -> B x C x 1 x T
att_conv = self.loc_conv(att_prev.view(batch, 1, 1, self.h_length))
# apply non-linear
att_conv = F.relu(att_conv)
# B x C x 1 x T -> B x C x 1 x 1 -> B x C
att_conv = F.max_pool2d(att_conv, (1, att_conv.size(3))).view(batch, -1)
att_h, att_c = self.att_lstm(att_conv, att_states)
# dec_z_tiled: utt x frame x att_dim
dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)
# dot with gvec
# utt x frame x att_dim -> utt x frame
e = self.gvec(torch.tanh(att_h.unsqueeze(1) + self.pre_compute_enc_h + dec_z_tiled)).squeeze(2)
# NOTE consider zero padding when compute w.
if self.mask is None:
self.mask = to_device(self, make_pad_mask(enc_hs_len))
e.masked_fill_(self.mask, -float('inf'))
w = F.softmax(scaling * e, dim=1)
# weighted sum over flames
# utt x hdim
# NOTE use bmm instead of sum(*)
c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)
return c, (w, (att_h, att_c))
__init__(self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False)
special
¶Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/attentions.py
def __init__(self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False):
super(AttLocRec, self).__init__()
self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
self.loc_conv = torch.nn.Conv2d(
1, aconv_chans, (1, 2 * aconv_filts + 1), padding=(0, aconv_filts), bias=False)
self.att_lstm = torch.nn.LSTMCell(aconv_chans, att_dim, bias=False)
self.gvec = torch.nn.Linear(att_dim, 1)
self.dunits = dunits
self.eprojs = eprojs
self.att_dim = att_dim
self.h_length = None
self.enc_h = None
self.pre_compute_enc_h = None
self.mask = None
self.han_mode = han_mode
forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev_states, scaling=2.0)
¶AttLocRec forward
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
:param list enc_hs_len: padded encoder hidden state length (B)
:param torch.Tensor dec_z: decoder hidden state (B x D_dec)
:param tuple att_prev_states: previous attention weight and lstm states ((B, T_max), ((B, att_dim), (B, att_dim)))
:param float scaling: scaling parameter before applying softmax
:return: attention weighted encoder state (B, D_enc)
:rtype: torch.Tensor
:return: previous attention weights and lstm states (w, (hx, cx)) ((B, T_max), ((B, att_dim), (B, att_dim)))
:rtype: tuple
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/attentions.py
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev_states, scaling=2.0):
"""AttLocRec forward
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
:param list enc_hs_len: padded encoder hidden state length (B)
:param torch.Tensor dec_z: decoder hidden state (B x D_dec)
:param tuple att_prev_states: previous attention weight and lstm states
((B, T_max), ((B, att_dim), (B, att_dim)))
:param float scaling: scaling parameter before applying softmax
:return: attention weighted encoder state (B, D_enc)
:rtype: torch.Tensor
:return: previous attention weights and lstm states (w, (hx, cx))
((B, T_max), ((B, att_dim), (B, att_dim)))
:rtype: tuple
"""
batch = len(enc_hs_pad)
# pre-compute all h outside the decoder loop
if self.pre_compute_enc_h is None or self.han_mode:
self.enc_h = enc_hs_pad # utt x frame x hdim
self.h_length = self.enc_h.size(1)
# utt x frame x att_dim
self.pre_compute_enc_h = self.mlp_enc(self.enc_h)
if dec_z is None:
dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
else:
dec_z = dec_z.view(batch, self.dunits)
if att_prev_states is None:
# initialize attention weight with uniform dist.
# if no bias, 0 0-pad goes 0
att_prev = to_device(self, (1. - make_pad_mask(enc_hs_len).float()))
att_prev = att_prev / att_prev.new(enc_hs_len).unsqueeze(-1)
# initialize lstm states
att_h = enc_hs_pad.new_zeros(batch, self.att_dim)
att_c = enc_hs_pad.new_zeros(batch, self.att_dim)
att_states = (att_h, att_c)
else:
att_prev = att_prev_states[0]
att_states = att_prev_states[1]
# B x 1 x 1 x T -> B x C x 1 x T
att_conv = self.loc_conv(att_prev.view(batch, 1, 1, self.h_length))
# apply non-linear
att_conv = F.relu(att_conv)
# B x C x 1 x T -> B x C x 1 x 1 -> B x C
att_conv = F.max_pool2d(att_conv, (1, att_conv.size(3))).view(batch, -1)
att_h, att_c = self.att_lstm(att_conv, att_states)
# dec_z_tiled: utt x frame x att_dim
dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)
# dot with gvec
# utt x frame x att_dim -> utt x frame
e = self.gvec(torch.tanh(att_h.unsqueeze(1) + self.pre_compute_enc_h + dec_z_tiled)).squeeze(2)
# NOTE consider zero padding when compute w.
if self.mask is None:
self.mask = to_device(self, make_pad_mask(enc_hs_len))
e.masked_fill_(self.mask, -float('inf'))
w = F.softmax(scaling * e, dim=1)
# weighted sum over flames
# utt x hdim
# NOTE use bmm instead of sum(*)
c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)
return c, (w, (att_h, att_c))
reset(self)
¶
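Usage sketch for AttLocRec (illustrative assumptions as above): the state threaded through the decoder loop is a tuple of the previous weights and the attention LSTM states.
import torch
from tools.espnet_minimal.nets.pytorch_backend.rnn.attentions import AttLocRec  # import path assumed
att = AttLocRec(eprojs=320, dunits=300, att_dim=320, aconv_chans=10, aconv_filts=100)
att.reset()
hs = torch.randn(2, 80, 320)            # B x T_max x eprojs
hlens = [80, 72]
dec_z = torch.randn(2, 300)
c, (w, (att_h, att_c)) = att(hs, hlens, dec_z, None)                  # first step: weights and LSTM states are initialised internally
c, (w, (att_h, att_c)) = att(hs, hlens, dec_z, (w, (att_h, att_c)))   # later steps: pass the whole state tuple back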
AttMultiHeadAdd (Module)
¶Multi head additive attention
Attention is all you need
(https://arxiv.org/abs/1706.03762)
This attention is multi head attention using additive attention for each head.
:param int eprojs: # projection-units of encoder
:param int dunits: # units of decoder
:param int aheads: # heads of multi head attention
:param int att_dim_k: dimension k in multi head attention
:param int att_dim_v: dimension v in multi head attention
:param bool han_mode: flag to switch on hierarchical attention mode and not store pre_compute_k and pre_compute_v
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/attentions.py
class AttMultiHeadAdd(torch.nn.Module):
"""Multi head additive attention
Reference: Attention is all you need
(https://arxiv.org/abs/1706.03762)
This attention is multi head attention using additive attention for each head.
:param int eprojs: # projection-units of encoder
:param int dunits: # units of decoder
:param int aheads: # heads of multi head attention
:param int att_dim_k: dimension k in multi head attention
:param int att_dim_v: dimension v in multi head attention
:param bool han_mode: flag to swith on mode of hierarchical attention and not store pre_compute_k and pre_compute_v
"""
def __init__(self, eprojs, dunits, aheads, att_dim_k, att_dim_v, han_mode=False):
super(AttMultiHeadAdd, self).__init__()
self.mlp_q = torch.nn.ModuleList()
self.mlp_k = torch.nn.ModuleList()
self.mlp_v = torch.nn.ModuleList()
self.gvec = torch.nn.ModuleList()
for _ in six.moves.range(aheads):
self.mlp_q += [torch.nn.Linear(dunits, att_dim_k)]
self.mlp_k += [torch.nn.Linear(eprojs, att_dim_k, bias=False)]
self.mlp_v += [torch.nn.Linear(eprojs, att_dim_v, bias=False)]
self.gvec += [torch.nn.Linear(att_dim_k, 1)]
self.mlp_o = torch.nn.Linear(aheads * att_dim_v, eprojs, bias=False)
self.dunits = dunits
self.eprojs = eprojs
self.aheads = aheads
self.att_dim_k = att_dim_k
self.att_dim_v = att_dim_v
self.scaling = 1.0 / math.sqrt(att_dim_k)
self.h_length = None
self.enc_h = None
self.pre_compute_k = None
self.pre_compute_v = None
self.mask = None
self.han_mode = han_mode
def reset(self):
"""reset states"""
self.h_length = None
self.enc_h = None
self.pre_compute_k = None
self.pre_compute_v = None
self.mask = None
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev):
"""AttMultiHeadAdd forward
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
:param list enc_hs_len: padded encoder hidden state length (B)
:param torch.Tensor dec_z: decoder hidden state (B x D_dec)
:param torch.Tensor att_prev: dummy (does not use)
:return: attention weighted encoder state (B, D_enc)
:rtype: torch.Tensor
:return: list of previous attention weight (B x T_max) * aheads
:rtype: list
"""
batch = enc_hs_pad.size(0)
# pre-compute all k and v outside the decoder loop
if self.pre_compute_k is None or self.han_mode:
self.enc_h = enc_hs_pad # utt x frame x hdim
self.h_length = self.enc_h.size(1)
# utt x frame x att_dim
self.pre_compute_k = [
self.mlp_k[h](self.enc_h) for h in six.moves.range(self.aheads)]
if self.pre_compute_v is None or self.han_mode:
self.enc_h = enc_hs_pad # utt x frame x hdim
self.h_length = self.enc_h.size(1)
# utt x frame x att_dim
self.pre_compute_v = [
self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads)]
if dec_z is None:
dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
else:
dec_z = dec_z.view(batch, self.dunits)
c = []
w = []
for h in six.moves.range(self.aheads):
e = self.gvec[h](torch.tanh(
self.pre_compute_k[h] + self.mlp_q[h](dec_z).view(batch, 1, self.att_dim_k))).squeeze(2)
# NOTE consider zero padding when compute w.
if self.mask is None:
self.mask = to_device(self, make_pad_mask(enc_hs_len))
e.masked_fill_(self.mask, -float('inf'))
w += [F.softmax(self.scaling * e, dim=1)]
# weighted sum over flames
# utt x hdim
# NOTE use bmm instead of sum(*)
c += [torch.sum(self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1)]
# concat all of c
c = self.mlp_o(torch.cat(c, dim=1))
return c, w
__init__(self, eprojs, dunits, aheads, att_dim_k, att_dim_v, han_mode=False)
special
¶Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/attentions.py
def __init__(self, eprojs, dunits, aheads, att_dim_k, att_dim_v, han_mode=False):
super(AttMultiHeadAdd, self).__init__()
self.mlp_q = torch.nn.ModuleList()
self.mlp_k = torch.nn.ModuleList()
self.mlp_v = torch.nn.ModuleList()
self.gvec = torch.nn.ModuleList()
for _ in six.moves.range(aheads):
self.mlp_q += [torch.nn.Linear(dunits, att_dim_k)]
self.mlp_k += [torch.nn.Linear(eprojs, att_dim_k, bias=False)]
self.mlp_v += [torch.nn.Linear(eprojs, att_dim_v, bias=False)]
self.gvec += [torch.nn.Linear(att_dim_k, 1)]
self.mlp_o = torch.nn.Linear(aheads * att_dim_v, eprojs, bias=False)
self.dunits = dunits
self.eprojs = eprojs
self.aheads = aheads
self.att_dim_k = att_dim_k
self.att_dim_v = att_dim_v
self.scaling = 1.0 / math.sqrt(att_dim_k)
self.h_length = None
self.enc_h = None
self.pre_compute_k = None
self.pre_compute_v = None
self.mask = None
self.han_mode = han_mode
forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev)
¶AttMultiHeadAdd forward
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
:param list enc_hs_len: padded encoder hidden state length (B)
:param torch.Tensor dec_z: decoder hidden state (B x D_dec)
:param torch.Tensor att_prev: dummy (does not use)
:return: attention weighted encoder state (B, D_enc)
:rtype: torch.Tensor
:return: list of previous attention weight (B x T_max) * aheads
:rtype: list
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/attentions.py
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev):
"""AttMultiHeadAdd forward
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
:param list enc_hs_len: padded encoder hidden state length (B)
:param torch.Tensor dec_z: decoder hidden state (B x D_dec)
:param torch.Tensor att_prev: dummy (does not use)
:return: attention weighted encoder state (B, D_enc)
:rtype: torch.Tensor
:return: list of previous attention weight (B x T_max) * aheads
:rtype: list
"""
batch = enc_hs_pad.size(0)
# pre-compute all k and v outside the decoder loop
if self.pre_compute_k is None or self.han_mode:
self.enc_h = enc_hs_pad # utt x frame x hdim
self.h_length = self.enc_h.size(1)
# utt x frame x att_dim
self.pre_compute_k = [
self.mlp_k[h](self.enc_h) for h in six.moves.range(self.aheads)]
if self.pre_compute_v is None or self.han_mode:
self.enc_h = enc_hs_pad # utt x frame x hdim
self.h_length = self.enc_h.size(1)
# utt x frame x att_dim
self.pre_compute_v = [
self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads)]
if dec_z is None:
dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
else:
dec_z = dec_z.view(batch, self.dunits)
c = []
w = []
for h in six.moves.range(self.aheads):
e = self.gvec[h](torch.tanh(
self.pre_compute_k[h] + self.mlp_q[h](dec_z).view(batch, 1, self.att_dim_k))).squeeze(2)
# NOTE consider zero padding when compute w.
if self.mask is None:
self.mask = to_device(self, make_pad_mask(enc_hs_len))
e.masked_fill_(self.mask, -float('inf'))
w += [F.softmax(self.scaling * e, dim=1)]
# weighted sum over flames
# utt x hdim
# NOTE use bmm instead of sum(*)
c += [torch.sum(self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1)]
# concat all of c
c = self.mlp_o(torch.cat(c, dim=1))
return c, w
reset(self)
¶
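Usage sketch for AttMultiHeadAdd (illustrative assumptions as above): the context vector is projected back to eprojs dimensions by mlp_o, and the second return value is a list with one weight tensor per head.
import torch
from tools.espnet_minimal.nets.pytorch_backend.rnn.attentions import AttMultiHeadAdd  # import path assumed
att = AttMultiHeadAdd(eprojs=320, dunits=300, aheads=4, att_dim_k=64, att_dim_v=64)
att.reset()
hs = torch.randn(2, 50, 320)         # B x T_max x eprojs
hlens = [50, 41]
dec_z = torch.randn(2, 300)
c, w = att(hs, hlens, dec_z, None)   # att_prev is ignored
print(c.shape, len(w), w[0].shape)   # torch.Size([2, 320]) 4 torch.Size([2, 50])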
AttMultiHeadDot (Module)
¶Multi head dot product attention
Attention is all you need
(https://arxiv.org/abs/1706.03762)
:param int eprojs: # projection-units of encoder
:param int dunits: # units of decoder
:param int aheads: # heads of multi head attention
:param int att_dim_k: dimension k in multi head attention
:param int att_dim_v: dimension v in multi head attention
:param bool han_mode: flag to switch on hierarchical attention mode and not store pre_compute_k and pre_compute_v
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/attentions.py
class AttMultiHeadDot(torch.nn.Module):
"""Multi head dot product attention
Reference: Attention is all you need
(https://arxiv.org/abs/1706.03762)
:param int eprojs: # projection-units of encoder
:param int dunits: # units of decoder
:param int aheads: # heads of multi head attention
:param int att_dim_k: dimension k in multi head attention
:param int att_dim_v: dimension v in multi head attention
:param bool han_mode: flag to swith on mode of hierarchical attention and not store pre_compute_k and pre_compute_v
"""
def __init__(self, eprojs, dunits, aheads, att_dim_k, att_dim_v, han_mode=False):
super(AttMultiHeadDot, self).__init__()
self.mlp_q = torch.nn.ModuleList()
self.mlp_k = torch.nn.ModuleList()
self.mlp_v = torch.nn.ModuleList()
for _ in six.moves.range(aheads):
self.mlp_q += [torch.nn.Linear(dunits, att_dim_k)]
self.mlp_k += [torch.nn.Linear(eprojs, att_dim_k, bias=False)]
self.mlp_v += [torch.nn.Linear(eprojs, att_dim_v, bias=False)]
self.mlp_o = torch.nn.Linear(aheads * att_dim_v, eprojs, bias=False)
self.dunits = dunits
self.eprojs = eprojs
self.aheads = aheads
self.att_dim_k = att_dim_k
self.att_dim_v = att_dim_v
self.scaling = 1.0 / math.sqrt(att_dim_k)
self.h_length = None
self.enc_h = None
self.pre_compute_k = None
self.pre_compute_v = None
self.mask = None
self.han_mode = han_mode
def reset(self):
"""reset states"""
self.h_length = None
self.enc_h = None
self.pre_compute_k = None
self.pre_compute_v = None
self.mask = None
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev):
"""AttMultiHeadDot forward
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
:param list enc_hs_len: padded encoder hidden state length (B)
:param torch.Tensor dec_z: decoder hidden state (B x D_dec)
:param torch.Tensor att_prev: dummy (does not use)
:return: attention weighted encoder state (B x D_enc)
:rtype: torch.Tensor
:return: list of previous attention weight (B x T_max) * aheads
:rtype: list
"""
batch = enc_hs_pad.size(0)
# pre-compute all k and v outside the decoder loop
if self.pre_compute_k is None or self.han_mode:
self.enc_h = enc_hs_pad # utt x frame x hdim
self.h_length = self.enc_h.size(1)
# utt x frame x att_dim
self.pre_compute_k = [
torch.tanh(self.mlp_k[h](self.enc_h)) for h in six.moves.range(self.aheads)]
if self.pre_compute_v is None or self.han_mode:
self.enc_h = enc_hs_pad # utt x frame x hdim
self.h_length = self.enc_h.size(1)
# utt x frame x att_dim
self.pre_compute_v = [
self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads)]
if dec_z is None:
dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
else:
dec_z = dec_z.view(batch, self.dunits)
c = []
w = []
for h in six.moves.range(self.aheads):
e = torch.sum(self.pre_compute_k[h] * torch.tanh(self.mlp_q[h](dec_z)).view(
batch, 1, self.att_dim_k), dim=2) # utt x frame
# NOTE consider zero padding when compute w.
if self.mask is None:
self.mask = to_device(self, make_pad_mask(enc_hs_len))
e.masked_fill_(self.mask, -float('inf'))
w += [F.softmax(self.scaling * e, dim=1)]
# weighted sum over flames
# utt x hdim
# NOTE use bmm instead of sum(*)
c += [torch.sum(self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1)]
# concat all of c
c = self.mlp_o(torch.cat(c, dim=1))
return c, w
__init__(self, eprojs, dunits, aheads, att_dim_k, att_dim_v, han_mode=False)
special
¶Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/attentions.py
def __init__(self, eprojs, dunits, aheads, att_dim_k, att_dim_v, han_mode=False):
super(AttMultiHeadDot, self).__init__()
self.mlp_q = torch.nn.ModuleList()
self.mlp_k = torch.nn.ModuleList()
self.mlp_v = torch.nn.ModuleList()
for _ in six.moves.range(aheads):
self.mlp_q += [torch.nn.Linear(dunits, att_dim_k)]
self.mlp_k += [torch.nn.Linear(eprojs, att_dim_k, bias=False)]
self.mlp_v += [torch.nn.Linear(eprojs, att_dim_v, bias=False)]
self.mlp_o = torch.nn.Linear(aheads * att_dim_v, eprojs, bias=False)
self.dunits = dunits
self.eprojs = eprojs
self.aheads = aheads
self.att_dim_k = att_dim_k
self.att_dim_v = att_dim_v
self.scaling = 1.0 / math.sqrt(att_dim_k)
self.h_length = None
self.enc_h = None
self.pre_compute_k = None
self.pre_compute_v = None
self.mask = None
self.han_mode = han_mode
forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev)
¶AttMultiHeadDot forward
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
:param list enc_hs_len: padded encoder hidden state length (B)
:param torch.Tensor dec_z: decoder hidden state (B x D_dec)
:param torch.Tensor att_prev: dummy (does not use)
:return: attention weighted encoder state (B x D_enc)
:rtype: torch.Tensor
:return: list of previous attention weight (B x T_max) * aheads
:rtype: list
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/attentions.py
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev):
"""AttMultiHeadDot forward
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
:param list enc_hs_len: padded encoder hidden state length (B)
:param torch.Tensor dec_z: decoder hidden state (B x D_dec)
:param torch.Tensor att_prev: dummy (does not use)
:return: attention weighted encoder state (B x D_enc)
:rtype: torch.Tensor
:return: list of previous attention weight (B x T_max) * aheads
:rtype: list
"""
batch = enc_hs_pad.size(0)
# pre-compute all k and v outside the decoder loop
if self.pre_compute_k is None or self.han_mode:
self.enc_h = enc_hs_pad # utt x frame x hdim
self.h_length = self.enc_h.size(1)
# utt x frame x att_dim
self.pre_compute_k = [
torch.tanh(self.mlp_k[h](self.enc_h)) for h in six.moves.range(self.aheads)]
if self.pre_compute_v is None or self.han_mode:
self.enc_h = enc_hs_pad # utt x frame x hdim
self.h_length = self.enc_h.size(1)
# utt x frame x att_dim
self.pre_compute_v = [
self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads)]
if dec_z is None:
dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
else:
dec_z = dec_z.view(batch, self.dunits)
c = []
w = []
for h in six.moves.range(self.aheads):
e = torch.sum(self.pre_compute_k[h] * torch.tanh(self.mlp_q[h](dec_z)).view(
batch, 1, self.att_dim_k), dim=2) # utt x frame
# NOTE consider zero padding when compute w.
if self.mask is None:
self.mask = to_device(self, make_pad_mask(enc_hs_len))
e.masked_fill_(self.mask, -float('inf'))
w += [F.softmax(self.scaling * e, dim=1)]
# weighted sum over frames
# utt x hdim
# NOTE use bmm instead of sum(*)
c += [torch.sum(self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1)]
# concat all of c
c = self.mlp_o(torch.cat(c, dim=1))
return c, w
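A minimal usage sketch (not part of the source): it feeds one batch of random encoder states through AttMultiHeadDot. The sizes and the import path are assumptions chosen only for illustration.
import torch
# import path assumed to mirror the file location quoted above
from tools.espnet_minimal.nets.pytorch_backend.rnn.attentions import AttMultiHeadDot

eprojs, dunits, aheads, adim = 320, 300, 4, 64   # hypothetical sizes
att = AttMultiHeadDot(eprojs, dunits, aheads, adim, adim)

batch, t_max = 2, 50
enc_hs_pad = torch.randn(batch, t_max, eprojs)   # B x T_max x D_enc
enc_hs_len = [50, 42]                            # true lengths per utterance
dec_z = torch.randn(batch, dunits)               # current decoder state

att.reset()                                      # clear cached keys/values/mask
c, w = att(enc_hs_pad, enc_hs_len, dec_z, None)  # att_prev is a dummy here
print(c.shape)              # torch.Size([2, 320]); B x D_enc
print(len(w), w[0].shape)   # aheads weight tensors, each B x T_max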
reset(self)
¶
AttMultiHeadLoc (Module)
¶
Multi head location based attention
Reference: Attention is all you need
(https://arxiv.org/abs/1706.03762)
This attention is multi head attention using location-aware attention for each head.
:param int eprojs: # projection-units of encoder
:param int dunits: # units of decoder
:param int aheads: # heads of multi head attention
:param int att_dim_k: dimension k in multi head attention
:param int att_dim_v: dimension v in multi head attention
:param int aconv_chans: # channels of attention convolution
:param int aconv_filts: filter size of attention convolution
:param bool han_mode: flag to switch on hierarchical attention mode and not store pre_compute_k and pre_compute_v
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/attentions.py
class AttMultiHeadLoc(torch.nn.Module):
"""Multi head location based attention
Reference: Attention is all you need
(https://arxiv.org/abs/1706.03762)
This attention is multi head attention using location-aware attention for each head.
:param int eprojs: # projection-units of encoder
:param int dunits: # units of decoder
:param int aheads: # heads of multi head attention
:param int att_dim_k: dimension k in multi head attention
:param int att_dim_v: dimension v in multi head attention
:param int aconv_chans: # channels of attention convolution
:param int aconv_filts: filter size of attention convolution
:param bool han_mode: flag to switch on hierarchical attention mode and not store pre_compute_k and pre_compute_v
"""
def __init__(self, eprojs, dunits, aheads, att_dim_k, att_dim_v, aconv_chans, aconv_filts, han_mode=False):
super(AttMultiHeadLoc, self).__init__()
self.mlp_q = torch.nn.ModuleList()
self.mlp_k = torch.nn.ModuleList()
self.mlp_v = torch.nn.ModuleList()
self.gvec = torch.nn.ModuleList()
self.loc_conv = torch.nn.ModuleList()
self.mlp_att = torch.nn.ModuleList()
for _ in six.moves.range(aheads):
self.mlp_q += [torch.nn.Linear(dunits, att_dim_k)]
self.mlp_k += [torch.nn.Linear(eprojs, att_dim_k, bias=False)]
self.mlp_v += [torch.nn.Linear(eprojs, att_dim_v, bias=False)]
self.gvec += [torch.nn.Linear(att_dim_k, 1)]
self.loc_conv += [torch.nn.Conv2d(
1, aconv_chans, (1, 2 * aconv_filts + 1), padding=(0, aconv_filts), bias=False)]
self.mlp_att += [torch.nn.Linear(aconv_chans, att_dim_k, bias=False)]
self.mlp_o = torch.nn.Linear(aheads * att_dim_v, eprojs, bias=False)
self.dunits = dunits
self.eprojs = eprojs
self.aheads = aheads
self.att_dim_k = att_dim_k
self.att_dim_v = att_dim_v
self.scaling = 1.0 / math.sqrt(att_dim_k)
self.h_length = None
self.enc_h = None
self.pre_compute_k = None
self.pre_compute_v = None
self.mask = None
self.han_mode = han_mode
def reset(self):
"""reset states"""
self.h_length = None
self.enc_h = None
self.pre_compute_k = None
self.pre_compute_v = None
self.mask = None
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0):
"""AttMultiHeadLoc forward
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
:param list enc_hs_len: padded encoder hidden state length (B)
:param torch.Tensor dec_z: decoder hidden state (B x D_dec)
:param torch.Tensor att_prev: list of previous attention weight (B x T_max) * aheads
:param float scaling: scaling parameter before applying softmax
:return: attention weighted encoder state (B x D_enc)
:rtype: torch.Tensor
:return: list of previous attention weight (B x T_max) * aheads
:rtype: list
"""
batch = enc_hs_pad.size(0)
# pre-compute all k and v outside the decoder loop
if self.pre_compute_k is None or self.han_mode:
self.enc_h = enc_hs_pad # utt x frame x hdim
self.h_length = self.enc_h.size(1)
# utt x frame x att_dim
self.pre_compute_k = [
self.mlp_k[h](self.enc_h) for h in six.moves.range(self.aheads)]
if self.pre_compute_v is None or self.han_mode:
self.enc_h = enc_hs_pad # utt x frame x hdim
self.h_length = self.enc_h.size(1)
# utt x frame x att_dim
self.pre_compute_v = [
self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads)]
if dec_z is None:
dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
else:
dec_z = dec_z.view(batch, self.dunits)
if att_prev is None:
att_prev = []
for _ in six.moves.range(self.aheads):
# if no bias, 0 0-pad goes 0
mask = 1. - make_pad_mask(enc_hs_len).float()
att_prev += [to_device(self, mask / mask.new(enc_hs_len).unsqueeze(-1))]
c = []
w = []
for h in six.moves.range(self.aheads):
att_conv = self.loc_conv[h](att_prev[h].view(batch, 1, 1, self.h_length))
att_conv = att_conv.squeeze(2).transpose(1, 2)
att_conv = self.mlp_att[h](att_conv)
e = self.gvec[h](torch.tanh(
self.pre_compute_k[h] + att_conv + self.mlp_q[h](dec_z).view(
batch, 1, self.att_dim_k))).squeeze(2)
# NOTE consider zero padding when compute w.
if self.mask is None:
self.mask = to_device(self, make_pad_mask(enc_hs_len))
e.masked_fill_(self.mask, -float('inf'))
w += [F.softmax(scaling * e, dim=1)]
# weighted sum over frames
# utt x hdim
# NOTE use bmm instead of sum(*)
c += [torch.sum(self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1)]
# concat all of c
c = self.mlp_o(torch.cat(c, dim=1))
return c, w
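A hedged usage sketch (not part of the source): it shows the extra convolution arguments and how the returned per-head weights are fed back as att_prev on the next decoding step. The sizes and the import path are assumptions.
import torch
# import path assumed to mirror the file location quoted above
from tools.espnet_minimal.nets.pytorch_backend.rnn.attentions import AttMultiHeadLoc

att = AttMultiHeadLoc(eprojs=320, dunits=300, aheads=4,
                      att_dim_k=64, att_dim_v=64,
                      aconv_chans=10, aconv_filts=100)
att.reset()
enc_hs_pad = torch.randn(2, 50, 320)   # B x T_max x D_enc
enc_hs_len = [50, 45]
dec_z = torch.randn(2, 300)

# first step: att_prev=None, so uniform weights are built internally
c, w = att(enc_hs_pad, enc_hs_len, dec_z, None)
# later steps: feed the previous per-head weights back in
c, w = att(enc_hs_pad, enc_hs_len, dec_z, w)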
__init__(self, eprojs, dunits, aheads, att_dim_k, att_dim_v, aconv_chans, aconv_filts, han_mode=False)
special
¶
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/attentions.py
def __init__(self, eprojs, dunits, aheads, att_dim_k, att_dim_v, aconv_chans, aconv_filts, han_mode=False):
super(AttMultiHeadLoc, self).__init__()
self.mlp_q = torch.nn.ModuleList()
self.mlp_k = torch.nn.ModuleList()
self.mlp_v = torch.nn.ModuleList()
self.gvec = torch.nn.ModuleList()
self.loc_conv = torch.nn.ModuleList()
self.mlp_att = torch.nn.ModuleList()
for _ in six.moves.range(aheads):
self.mlp_q += [torch.nn.Linear(dunits, att_dim_k)]
self.mlp_k += [torch.nn.Linear(eprojs, att_dim_k, bias=False)]
self.mlp_v += [torch.nn.Linear(eprojs, att_dim_v, bias=False)]
self.gvec += [torch.nn.Linear(att_dim_k, 1)]
self.loc_conv += [torch.nn.Conv2d(
1, aconv_chans, (1, 2 * aconv_filts + 1), padding=(0, aconv_filts), bias=False)]
self.mlp_att += [torch.nn.Linear(aconv_chans, att_dim_k, bias=False)]
self.mlp_o = torch.nn.Linear(aheads * att_dim_v, eprojs, bias=False)
self.dunits = dunits
self.eprojs = eprojs
self.aheads = aheads
self.att_dim_k = att_dim_k
self.att_dim_v = att_dim_v
self.scaling = 1.0 / math.sqrt(att_dim_k)
self.h_length = None
self.enc_h = None
self.pre_compute_k = None
self.pre_compute_v = None
self.mask = None
self.han_mode = han_mode
forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0)
¶
AttMultiHeadLoc forward
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
:param list enc_hs_len: padded encoder hidden state length (B)
:param torch.Tensor dec_z: decoder hidden state (B x D_dec)
:param torch.Tensor att_prev: list of previous attention weight (B x T_max) * aheads
:param float scaling: scaling parameter before applying softmax
:return: attention weighted encoder state (B x D_enc)
:rtype: torch.Tensor
:return: list of previous attention weight (B x T_max) * aheads
:rtype: list
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/attentions.py
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0):
"""AttMultiHeadLoc forward
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
:param list enc_hs_len: padded encoder hidden state length (B)
:param torch.Tensor dec_z: decoder hidden state (B x D_dec)
:param torch.Tensor att_prev: list of previous attention weight (B x T_max) * aheads
:param float scaling: scaling parameter before applying softmax
:return: attention weighted encoder state (B x D_enc)
:rtype: torch.Tensor
:return: list of previous attention weight (B x T_max) * aheads
:rtype: list
"""
batch = enc_hs_pad.size(0)
# pre-compute all k and v outside the decoder loop
if self.pre_compute_k is None or self.han_mode:
self.enc_h = enc_hs_pad # utt x frame x hdim
self.h_length = self.enc_h.size(1)
# utt x frame x att_dim
self.pre_compute_k = [
self.mlp_k[h](self.enc_h) for h in six.moves.range(self.aheads)]
if self.pre_compute_v is None or self.han_mode:
self.enc_h = enc_hs_pad # utt x frame x hdim
self.h_length = self.enc_h.size(1)
# utt x frame x att_dim
self.pre_compute_v = [
self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads)]
if dec_z is None:
dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
else:
dec_z = dec_z.view(batch, self.dunits)
if att_prev is None:
att_prev = []
for _ in six.moves.range(self.aheads):
# if no bias, 0 0-pad goes 0
mask = 1. - make_pad_mask(enc_hs_len).float()
att_prev += [to_device(self, mask / mask.new(enc_hs_len).unsqueeze(-1))]
c = []
w = []
for h in six.moves.range(self.aheads):
att_conv = self.loc_conv[h](att_prev[h].view(batch, 1, 1, self.h_length))
att_conv = att_conv.squeeze(2).transpose(1, 2)
att_conv = self.mlp_att[h](att_conv)
e = self.gvec[h](torch.tanh(
self.pre_compute_k[h] + att_conv + self.mlp_q[h](dec_z).view(
batch, 1, self.att_dim_k))).squeeze(2)
# NOTE consider zero padding when compute w.
if self.mask is None:
self.mask = to_device(self, make_pad_mask(enc_hs_len))
e.masked_fill_(self.mask, -float('inf'))
w += [F.softmax(scaling * e, dim=1)]
# weighted sum over frames
# utt x hdim
# NOTE use bmm instead of sum(*)
c += [torch.sum(self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1)]
# concat all of c
c = self.mlp_o(torch.cat(c, dim=1))
return c, w
reset(self)
¶
AttMultiHeadMultiResLoc (Module)
¶
Multi head multi resolution location based attention
Reference: Attention is all you need
(https://arxiv.org/abs/1706.03762)
This attention is multi head attention using location-aware attention for each head. Furthermore, it uses a different filter size for each head.
:param int eprojs: # projection-units of encoder
:param int dunits: # units of decoder
:param int aheads: # heads of multi head attention
:param int att_dim_k: dimension k in multi head attention
:param int att_dim_v: dimension v in multi head attention
:param int aconv_chans: # channels of attention convolution
:param int aconv_filts: maximum filter size of attention convolution; head h uses half-width aconv_filts * (h + 1) // aheads, e.g. aheads=4, aconv_filts=100 => half-widths 25, 50, 75, 100
:param bool han_mode: flag to switch on hierarchical attention mode and not store pre_compute_k and pre_compute_v
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/attentions.py
class AttMultiHeadMultiResLoc(torch.nn.Module):
"""Multi head multi resolution location based attention
Reference: Attention is all you need
(https://arxiv.org/abs/1706.03762)
This attention is multi head attention using location-aware attention for each head.
Furthermore, it uses a different filter size for each head.
:param int eprojs: # projection-units of encoder
:param int dunits: # units of decoder
:param int aheads: # heads of multi head attention
:param int att_dim_k: dimension k in multi head attention
:param int att_dim_v: dimension v in multi head attention
:param int aconv_chans: # channels of attention convolution
:param int aconv_filts: maximum filter size of attention convolution
each head h uses half-width aconv_filts * (h + 1) // aheads
e.g. aheads=4, aconv_filts=100 => half-widths 25, 50, 75, 100
:param bool han_mode: flag to switch on hierarchical attention mode and not store pre_compute_k and pre_compute_v
"""
def __init__(self, eprojs, dunits, aheads, att_dim_k, att_dim_v, aconv_chans, aconv_filts, han_mode=False):
super(AttMultiHeadMultiResLoc, self).__init__()
self.mlp_q = torch.nn.ModuleList()
self.mlp_k = torch.nn.ModuleList()
self.mlp_v = torch.nn.ModuleList()
self.gvec = torch.nn.ModuleList()
self.loc_conv = torch.nn.ModuleList()
self.mlp_att = torch.nn.ModuleList()
for h in six.moves.range(aheads):
self.mlp_q += [torch.nn.Linear(dunits, att_dim_k)]
self.mlp_k += [torch.nn.Linear(eprojs, att_dim_k, bias=False)]
self.mlp_v += [torch.nn.Linear(eprojs, att_dim_v, bias=False)]
self.gvec += [torch.nn.Linear(att_dim_k, 1)]
afilts = aconv_filts * (h + 1) // aheads
self.loc_conv += [torch.nn.Conv2d(
1, aconv_chans, (1, 2 * afilts + 1), padding=(0, afilts), bias=False)]
self.mlp_att += [torch.nn.Linear(aconv_chans, att_dim_k, bias=False)]
self.mlp_o = torch.nn.Linear(aheads * att_dim_v, eprojs, bias=False)
self.dunits = dunits
self.eprojs = eprojs
self.aheads = aheads
self.att_dim_k = att_dim_k
self.att_dim_v = att_dim_v
self.scaling = 1.0 / math.sqrt(att_dim_k)
self.h_length = None
self.enc_h = None
self.pre_compute_k = None
self.pre_compute_v = None
self.mask = None
self.han_mode = han_mode
def reset(self):
"""reset states"""
self.h_length = None
self.enc_h = None
self.pre_compute_k = None
self.pre_compute_v = None
self.mask = None
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev):
"""AttMultiHeadMultiResLoc forward
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
:param list enc_hs_len: padded encoder hidden state length (B)
:param torch.Tensor dec_z: decoder hidden state (B x D_dec)
:param torch.Tensor att_prev: list of previous attention weight (B x T_max) * aheads
:return: attention weighted encoder state (B x D_enc)
:rtype: torch.Tensor
:return: list of previous attention weight (B x T_max) * aheads
:rtype: list
"""
batch = enc_hs_pad.size(0)
# pre-compute all k and v outside the decoder loop
if self.pre_compute_k is None or self.han_mode:
self.enc_h = enc_hs_pad # utt x frame x hdim
self.h_length = self.enc_h.size(1)
# utt x frame x att_dim
self.pre_compute_k = [
self.mlp_k[h](self.enc_h) for h in six.moves.range(self.aheads)]
if self.pre_compute_v is None or self.han_mode:
self.enc_h = enc_hs_pad # utt x frame x hdim
self.h_length = self.enc_h.size(1)
# utt x frame x att_dim
self.pre_compute_v = [
self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads)]
if dec_z is None:
dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
else:
dec_z = dec_z.view(batch, self.dunits)
if att_prev is None:
att_prev = []
for _ in six.moves.range(self.aheads):
# if no bias, 0 0-pad goes 0
mask = 1. - make_pad_mask(enc_hs_len).float()
att_prev += [to_device(self, mask / mask.new(enc_hs_len).unsqueeze(-1))]
c = []
w = []
for h in six.moves.range(self.aheads):
att_conv = self.loc_conv[h](att_prev[h].view(batch, 1, 1, self.h_length))
att_conv = att_conv.squeeze(2).transpose(1, 2)
att_conv = self.mlp_att[h](att_conv)
e = self.gvec[h](torch.tanh(
self.pre_compute_k[h] + att_conv + self.mlp_q[h](dec_z).view(
batch, 1, self.att_dim_k))).squeeze(2)
# NOTE consider zero padding when compute w.
if self.mask is None:
self.mask = to_device(self, make_pad_mask(enc_hs_len))
e.masked_fill_(self.mask, -float('inf'))
w += [F.softmax(self.scaling * e, dim=1)]
# weighted sum over frames
# utt x hdim
# NOTE use bmm instead of sum(*)
c += [torch.sum(self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1)]
# concat all of c
c = self.mlp_o(torch.cat(c, dim=1))
return c, w
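The per-head resolution schedule follows directly from the constructor line afilts = aconv_filts * (h + 1) // aheads: later heads convolve the previous attention weights with a wider window. A small sketch of the schedule (the values below are arbitrary examples):
aheads, aconv_filts = 4, 100
for h in range(aheads):
    afilts = aconv_filts * (h + 1) // aheads
    print(f"head {h}: half-width {afilts}, kernel size {2 * afilts + 1}")
# head 0: half-width 25, kernel size 51
# head 1: half-width 50, kernel size 101
# head 2: half-width 75, kernel size 151
# head 3: half-width 100, kernel size 201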
__init__(self, eprojs, dunits, aheads, att_dim_k, att_dim_v, aconv_chans, aconv_filts, han_mode=False)
special
¶
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/attentions.py
def __init__(self, eprojs, dunits, aheads, att_dim_k, att_dim_v, aconv_chans, aconv_filts, han_mode=False):
super(AttMultiHeadMultiResLoc, self).__init__()
self.mlp_q = torch.nn.ModuleList()
self.mlp_k = torch.nn.ModuleList()
self.mlp_v = torch.nn.ModuleList()
self.gvec = torch.nn.ModuleList()
self.loc_conv = torch.nn.ModuleList()
self.mlp_att = torch.nn.ModuleList()
for h in six.moves.range(aheads):
self.mlp_q += [torch.nn.Linear(dunits, att_dim_k)]
self.mlp_k += [torch.nn.Linear(eprojs, att_dim_k, bias=False)]
self.mlp_v += [torch.nn.Linear(eprojs, att_dim_v, bias=False)]
self.gvec += [torch.nn.Linear(att_dim_k, 1)]
afilts = aconv_filts * (h + 1) // aheads
self.loc_conv += [torch.nn.Conv2d(
1, aconv_chans, (1, 2 * afilts + 1), padding=(0, afilts), bias=False)]
self.mlp_att += [torch.nn.Linear(aconv_chans, att_dim_k, bias=False)]
self.mlp_o = torch.nn.Linear(aheads * att_dim_v, eprojs, bias=False)
self.dunits = dunits
self.eprojs = eprojs
self.aheads = aheads
self.att_dim_k = att_dim_k
self.att_dim_v = att_dim_v
self.scaling = 1.0 / math.sqrt(att_dim_k)
self.h_length = None
self.enc_h = None
self.pre_compute_k = None
self.pre_compute_v = None
self.mask = None
self.han_mode = han_mode
forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev)
¶
AttMultiHeadMultiResLoc forward
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
:param list enc_hs_len: padded encoder hidden state length (B)
:param torch.Tensor dec_z: decoder hidden state (B x D_dec)
:param torch.Tensor att_prev: list of previous attention weight (B x T_max) * aheads
:return: attention weighted encoder state (B x D_enc)
:rtype: torch.Tensor
:return: list of previous attention weight (B x T_max) * aheads
:rtype: list
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/attentions.py
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev):
"""AttMultiHeadMultiResLoc forward
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
:param list enc_hs_len: padded encoder hidden state length (B)
:param torch.Tensor dec_z: decoder hidden state (B x D_dec)
:param torch.Tensor att_prev: list of previous attention weight (B x T_max) * aheads
:return: attention weighted encoder state (B x D_enc)
:rtype: torch.Tensor
:return: list of previous attention weight (B x T_max) * aheads
:rtype: list
"""
batch = enc_hs_pad.size(0)
# pre-compute all k and v outside the decoder loop
if self.pre_compute_k is None or self.han_mode:
self.enc_h = enc_hs_pad # utt x frame x hdim
self.h_length = self.enc_h.size(1)
# utt x frame x att_dim
self.pre_compute_k = [
self.mlp_k[h](self.enc_h) for h in six.moves.range(self.aheads)]
if self.pre_compute_v is None or self.han_mode:
self.enc_h = enc_hs_pad # utt x frame x hdim
self.h_length = self.enc_h.size(1)
# utt x frame x att_dim
self.pre_compute_v = [
self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads)]
if dec_z is None:
dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
else:
dec_z = dec_z.view(batch, self.dunits)
if att_prev is None:
att_prev = []
for _ in six.moves.range(self.aheads):
# if no bias, 0 0-pad goes 0
mask = 1. - make_pad_mask(enc_hs_len).float()
att_prev += [to_device(self, mask / mask.new(enc_hs_len).unsqueeze(-1))]
c = []
w = []
for h in six.moves.range(self.aheads):
att_conv = self.loc_conv[h](att_prev[h].view(batch, 1, 1, self.h_length))
att_conv = att_conv.squeeze(2).transpose(1, 2)
att_conv = self.mlp_att[h](att_conv)
e = self.gvec[h](torch.tanh(
self.pre_compute_k[h] + att_conv + self.mlp_q[h](dec_z).view(
batch, 1, self.att_dim_k))).squeeze(2)
# NOTE consider zero padding when compute w.
if self.mask is None:
self.mask = to_device(self, make_pad_mask(enc_hs_len))
e.masked_fill_(self.mask, -float('inf'))
w += [F.softmax(self.scaling * e, dim=1)]
# weighted sum over frames
# utt x hdim
# NOTE use bmm instead of sum(*)
c += [torch.sum(self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1)]
# concat all of c
c = self.mlp_o(torch.cat(c, dim=1))
return c, w
reset(self)
¶
NoAtt (Module)
¶
No attention
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/attentions.py
class NoAtt(torch.nn.Module):
"""No attention"""
def __init__(self):
super(NoAtt, self).__init__()
self.h_length = None
self.enc_h = None
self.pre_compute_enc_h = None
self.c = None
def reset(self):
"""reset states"""
self.h_length = None
self.enc_h = None
self.pre_compute_enc_h = None
self.c = None
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev):
"""NoAtt forward
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B, T_max, D_enc)
:param list enc_hs_len: padded encoder hidden state length (B)
:param torch.Tensor dec_z: dummy (does not use)
:param torch.Tensor att_prev: dummy (does not use)
:return: attention weighted encoder state (B, D_enc)
:rtype: torch.Tensor
:return: previous attention weights
:rtype: torch.Tensor
"""
batch = len(enc_hs_pad)
# pre-compute all h outside the decoder loop
if self.pre_compute_enc_h is None:
self.enc_h = enc_hs_pad # utt x frame x hdim
self.h_length = self.enc_h.size(1)
# initialize attention weight with uniform dist.
if att_prev is None:
# if no bias, 0 0-pad goes 0
mask = 1. - make_pad_mask(enc_hs_len).float()
att_prev = mask / mask.new(enc_hs_len).unsqueeze(-1)
att_prev = att_prev.to(self.enc_h)
self.c = torch.sum(self.enc_h * att_prev.view(batch, self.h_length, 1), dim=1)
return self.c, att_prev
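A hedged sketch (not from the source): NoAtt ignores dec_z and att_prev, averages the unpadded encoder frames with uniform weights on the first call, and returns the cached context vector afterwards. The sizes and the import path are assumptions.
import torch
# import path assumed to mirror the file location quoted above
from tools.espnet_minimal.nets.pytorch_backend.rnn.attentions import NoAtt

att = NoAtt()
att.reset()
enc_hs_pad = torch.randn(2, 50, 320)
enc_hs_len = [50, 30]
c, att_prev = att(enc_hs_pad, enc_hs_len, None, None)
print(c.shape)                  # torch.Size([2, 320])
print(att_prev[1, :30].sum())   # weights over the 30 real frames sum to ~1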
__init__(self)
special
¶
forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev)
¶
NoAtt forward
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B, T_max, D_enc)
:param list enc_hs_len: padded encoder hidden state length (B)
:param torch.Tensor dec_z: dummy (does not use)
:param torch.Tensor att_prev: dummy (does not use)
:return: attention weighted encoder state (B, D_enc)
:rtype: torch.Tensor
:return: previous attention weights
:rtype: torch.Tensor
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/attentions.py
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev):
"""NoAtt forward
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B, T_max, D_enc)
:param list enc_hs_len: padded encoder hidden state length (B)
:param torch.Tensor dec_z: dummy (does not use)
:param torch.Tensor att_prev: dummy (does not use)
:return: attention weighted encoder state (B, D_enc)
:rtype: torch.Tensor
:return: previous attention weights
:rtype: torch.Tensor
"""
batch = len(enc_hs_pad)
# pre-compute all h outside the decoder loop
if self.pre_compute_enc_h is None:
self.enc_h = enc_hs_pad # utt x frame x hdim
self.h_length = self.enc_h.size(1)
# initialize attention weight with uniform dist.
if att_prev is None:
# if no bias, 0 0-pad goes 0
mask = 1. - make_pad_mask(enc_hs_len).float()
att_prev = mask / mask.new(enc_hs_len).unsqueeze(-1)
att_prev = att_prev.to(self.enc_h)
self.c = torch.sum(self.enc_h * att_prev.view(batch, self.h_length, 1), dim=1)
return self.c, att_prev
reset(self)
¶
att_for(args, num_att=1, han_mode=False)
¶
Instantiates an attention module given the program arguments
:param Namespace args: The arguments
:param int num_att: number of attention modules (in multi-speaker case, it can be 2 or more)
:param bool han_mode: switch on/off mode of hierarchical attention network (HAN)
:rtype torch.nn.Module
:return: The attention module
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/attentions.py
def att_for(args, num_att=1, han_mode=False):
"""Instantiates an attention module given the program arguments
:param Namespace args: The arguments
:param int num_att: number of attention modules (in multi-speaker case, it can be 2 or more)
:param bool han_mode: switch on/off mode of hierarchical attention network (HAN)
:rtype torch.nn.Module
:return: The attention module
"""
att_list = torch.nn.ModuleList()
num_encs = getattr(args, "num_encs", 1) # use getattr to keep compatibility
aheads = getattr(args, 'aheads', None)
awin = getattr(args, 'awin', None)
aconv_chans = getattr(args, 'aconv_chans', None)
aconv_filts = getattr(args, 'aconv_filts', None)
if num_encs == 1:
for i in range(num_att):
att = initial_att(args.atype, args.eprojs, args.dunits, aheads, args.adim, awin, aconv_chans,
aconv_filts)
att_list.append(att)
elif num_encs > 1: # no multi-speaker mode
if han_mode:
att = initial_att(args.han_type, args.eprojs, args.dunits, args.han_heads, args.han_dim,
args.han_win, args.han_conv_chans, args.han_conv_filts, han_mode=True)
return att
else:
att_list = torch.nn.ModuleList()
for idx in range(num_encs):
att = initial_att(args.atype[idx], args.eprojs, args.dunits, aheads[idx], args.adim[idx],
awin[idx], aconv_chans[idx], aconv_filts[idx])
att_list.append(att)
else:
raise ValueError("Number of encoders needs to be at least one. {}".format(num_encs))
return att_list
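A minimal sketch of driving att_for from a hand-built Namespace (not from the source); the attribute values are hypothetical, only the attribute names follow the function body above.
from argparse import Namespace
# import path assumed to mirror the file location quoted above
from tools.espnet_minimal.nets.pytorch_backend.rnn.attentions import att_for

args = Namespace(atype='location', eprojs=320, dunits=300, adim=320,
                 aheads=4, awin=5, aconv_chans=10, aconv_filts=100)
att_list = att_for(args)           # ModuleList with one attention module
print(type(att_list[0]).__name__)  # AttLoc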
att_to_numpy(att_ws, att)
¶
Converts attention weights to a numpy array given the attention
:param list att_ws: The attention weights
:param torch.nn.Module att: The attention
:rtype: np.ndarray
:return: The numpy array of the attention weights
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/attentions.py
def att_to_numpy(att_ws, att):
"""Converts attention weights to a numpy array given the attention
:param list att_ws: The attention weights
:param torch.nn.Module att: The attention
:rtype: np.ndarray
:return: The numpy array of the attention weights
"""
# convert to numpy array with the shape (B, Lmax, Tmax)
if isinstance(att, AttLoc2D):
# att_ws => list of previous concate attentions
att_ws = torch.stack([aw[:, -1] for aw in att_ws], dim=1).cpu().numpy()
elif isinstance(att, (AttCov, AttCovLoc)):
# att_ws => list of list of previous attentions
att_ws = torch.stack([aw[idx] for idx, aw in enumerate(att_ws)], dim=1).cpu().numpy()
elif isinstance(att, AttLocRec):
# att_ws => list of tuple of attention and hidden states
att_ws = torch.stack([aw[0] for aw in att_ws], dim=1).cpu().numpy()
elif isinstance(att, (AttMultiHeadDot, AttMultiHeadAdd, AttMultiHeadLoc, AttMultiHeadMultiResLoc)):
# att_ws => list of list of each head attention
n_heads = len(att_ws[0])
att_ws_sorted_by_head = []
for h in six.moves.range(n_heads):
att_ws_head = torch.stack([aw[h] for aw in att_ws], dim=1)
att_ws_sorted_by_head += [att_ws_head]
att_ws = torch.stack(att_ws_sorted_by_head, dim=1).cpu().numpy()
else:
# att_ws => list of attentions
att_ws = torch.stack(att_ws, dim=1).cpu().numpy()
return att_ws
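A hedged sketch (not from the source): for a plain single-head attention such as AttAdd, the per-step weights (each B x Tmax) fall into the final branch and are simply stacked into a (B, Lmax, Tmax) array. The sizes below are arbitrary examples.
import torch
# import path assumed to mirror the file location quoted above
from tools.espnet_minimal.nets.pytorch_backend.rnn.attentions import AttAdd, att_to_numpy

att = AttAdd(320, 300, 320)                          # eprojs, dunits, attention dim
B, Lmax, Tmax = 2, 7, 50
att_ws = [torch.rand(B, Tmax) for _ in range(Lmax)]  # one weight tensor per output step
arr = att_to_numpy(att_ws, att)
print(arr.shape)                                     # (2, 7, 50)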
initial_att(atype, eprojs, dunits, aheads, adim, awin, aconv_chans, aconv_filts, han_mode=False)
¶
Instantiates a single attention module
:param str atype: attention type
:param int eprojs: # projection-units of encoder
:param int dunits: # units of decoder
:param int aheads: # heads of multi head attention
:param int adim: attention dimension
:param int awin: attention window size
:param int aconv_chans: # channels of attention convolution
:param int aconv_filts: filter size of attention convolution
:param bool han_mode: flag to switch on mode of hierarchical attention
:return: The attention module
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/attentions.py
def initial_att(atype, eprojs, dunits, aheads, adim, awin, aconv_chans, aconv_filts, han_mode=False):
"""Instantiates a single attention module
:param str atype: attention type
:param int eprojs: # projection-units of encoder
:param int dunits: # units of decoder
:param int aheads: # heads of multi head attention
:param int adim: attention dimension
:param int awin: attention window size
:param int aconv_chans: # channels of attention convolution
:param int aconv_filts: filter size of attention convolution
:param bool han_mode: flag to switch on mode of hierarchical attention
:return: The attention module
"""
if atype == 'noatt':
att = NoAtt()
elif atype == 'dot':
att = AttDot(eprojs, dunits, adim, han_mode)
elif atype == 'add':
att = AttAdd(eprojs, dunits, adim, han_mode)
elif atype == 'location':
att = AttLoc(eprojs, dunits,
adim, aconv_chans, aconv_filts, han_mode)
elif atype == 'location2d':
att = AttLoc2D(eprojs, dunits,
adim, awin, aconv_chans, aconv_filts, han_mode)
elif atype == 'location_recurrent':
att = AttLocRec(eprojs, dunits,
adim, aconv_chans, aconv_filts, han_mode)
elif atype == 'coverage':
att = AttCov(eprojs, dunits, adim, han_mode)
elif atype == 'coverage_location':
att = AttCovLoc(eprojs, dunits, adim,
aconv_chans, aconv_filts, han_mode)
elif atype == 'multi_head_dot':
att = AttMultiHeadDot(eprojs, dunits,
aheads, adim, adim, han_mode)
elif atype == 'multi_head_add':
att = AttMultiHeadAdd(eprojs, dunits,
aheads, adim, adim, han_mode)
elif atype == 'multi_head_loc':
att = AttMultiHeadLoc(eprojs, dunits,
aheads, adim, adim,
aconv_chans, aconv_filts, han_mode)
elif atype == 'multi_head_multi_res_loc':
att = AttMultiHeadMultiResLoc(eprojs, dunits,
aheads, adim, adim,
aconv_chans, aconv_filts, han_mode)
return att
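A short sketch (not from the source): the atype string selects the class, and arguments the chosen class does not need (e.g. awin for 'dot') are simply left unused. The values below are placeholders; note that an unrecognized atype leaves att unbound and raises an error.
# import path assumed to mirror the file location quoted above
from tools.espnet_minimal.nets.pytorch_backend.rnn.attentions import initial_att

att = initial_att('multi_head_loc', eprojs=320, dunits=300, aheads=4,
                  adim=64, awin=5, aconv_chans=10, aconv_filts=100)
print(type(att).__name__)  # AttMultiHeadLoc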
decoders
¶
CTC_SCORING_RATIO
¶
MAX_DECODER_OUTPUT
¶
Decoder (Module, ScorerInterface)
¶
Decoder module
:param int eprojs: encoder projection units
:param int odim: dimension of outputs
:param str dtype: gru or lstm
:param int dlayers: decoder layers
:param int dunits: decoder units
:param int sos: start of sequence symbol id
:param int eos: end of sequence symbol id
:param torch.nn.Module att: attention module
:param int verbose: verbose level
:param list char_list: list of character strings
:param ndarray labeldist: distribution of label smoothing
:param float lsm_weight: label smoothing weight
:param float sampling_probability: scheduled sampling probability
:param float dropout: dropout rate
:param bool context_residual: if True, use context vector for token generation
:param bool replace_sos: use for multilingual (speech/text) translation
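A hedged construction sketch (not from the source): it wires an attention list from att_for into the Decoder and runs one forward pass on random data. All sizes, symbol ids, and the import paths are hypothetical placeholders.
import torch
from argparse import Namespace
# import paths assumed to mirror the file locations quoted in this reference
from tools.espnet_minimal.nets.pytorch_backend.rnn.attentions import att_for
from tools.espnet_minimal.nets.pytorch_backend.rnn.decoders import Decoder

args = Namespace(atype='location', eprojs=320, dunits=300, adim=320,
                 aheads=4, awin=5, aconv_chans=10, aconv_filts=100)
odim = 52                                  # vocabulary size incl. <eos>
dec = Decoder(eprojs=320, odim=odim, dtype='lstm', dlayers=1, dunits=300,
              sos=odim - 1, eos=odim - 1, att=att_for(args))

hs_pad = torch.randn(2, 40, 320)           # B x Tmax x D
hlens = torch.tensor([40, 35])
ys_pad = torch.full((2, 6), dec.ignore_id, dtype=torch.long)
ys_pad[0, :6] = torch.randint(0, odim - 1, (6,))
ys_pad[1, :4] = torch.randint(0, odim - 1, (4,))
loss, acc, ppl = dec(hs_pad, hlens, ys_pad)  # attention CE loss, accuracy, perplexity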
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/decoders.py
class Decoder(torch.nn.Module, ScorerInterface):
"""Decoder module
:param int eprojs: encoder projection units
:param int odim: dimension of outputs
:param str dtype: gru or lstm
:param int dlayers: decoder layers
:param int dunits: decoder units
:param int sos: start of sequence symbol id
:param int eos: end of sequence symbol id
:param torch.nn.Module att: attention module
:param int verbose: verbose level
:param list char_list: list of character strings
:param ndarray labeldist: distribution of label smoothing
:param float lsm_weight: label smoothing weight
:param float sampling_probability: scheduled sampling probability
:param float dropout: dropout rate
:param float context_residual: if True, use context vector for token generation
:param float replace_sos: use for multilingual (speech/text) translation
"""
def __init__(self, eprojs, odim, dtype, dlayers, dunits, sos, eos, att, verbose=0,
char_list=None, labeldist=None, lsm_weight=0., sampling_probability=0.0,
dropout=0.0, context_residual=False, replace_sos=False, num_encs=1):
torch.nn.Module.__init__(self)
self.dtype = dtype
self.dunits = dunits
self.dlayers = dlayers
self.context_residual = context_residual
self.embed = torch.nn.Embedding(odim, dunits)
self.dropout_emb = torch.nn.Dropout(p=dropout)
self.decoder = torch.nn.ModuleList()
self.dropout_dec = torch.nn.ModuleList()
self.decoder += [
torch.nn.LSTMCell(dunits + eprojs, dunits) if self.dtype == "lstm" else torch.nn.GRUCell(dunits + eprojs,
dunits)]
self.dropout_dec += [torch.nn.Dropout(p=dropout)]
for _ in six.moves.range(1, self.dlayers):
self.decoder += [
torch.nn.LSTMCell(dunits, dunits) if self.dtype == "lstm" else torch.nn.GRUCell(dunits, dunits)]
self.dropout_dec += [torch.nn.Dropout(p=dropout)]
# NOTE: dropout is applied only for the vertical connections
# see https://arxiv.org/pdf/1409.2329.pdf
self.ignore_id = -1
if context_residual:
self.output = torch.nn.Linear(dunits + eprojs, odim)
else:
self.output = torch.nn.Linear(dunits, odim)
self.loss = None
self.att = att
self.dunits = dunits
self.sos = sos
self.eos = eos
self.odim = odim
self.verbose = verbose
self.char_list = char_list
# for label smoothing
self.labeldist = labeldist
self.vlabeldist = None
self.lsm_weight = lsm_weight
self.sampling_probability = sampling_probability
self.dropout = dropout
self.num_encs = num_encs
# for multilingual E2E-ST
self.replace_sos = replace_sos
self.logzero = -10000000000.0
def zero_state(self, hs_pad):
return hs_pad.new_zeros(hs_pad.size(0), self.dunits)
def rnn_forward(self, ey, z_list, c_list, z_prev, c_prev):
if self.dtype == "lstm":
z_list[0], c_list[0] = self.decoder[0](ey, (z_prev[0], c_prev[0]))
for l in six.moves.range(1, self.dlayers):
z_list[l], c_list[l] = self.decoder[l](
self.dropout_dec[l - 1](z_list[l - 1]), (z_prev[l], c_prev[l]))
else:
z_list[0] = self.decoder[0](ey, z_prev[0])
for l in six.moves.range(1, self.dlayers):
z_list[l] = self.decoder[l](self.dropout_dec[l - 1](z_list[l - 1]), z_prev[l])
return z_list, c_list
def forward(self, hs_pad, hlens, ys_pad, strm_idx=0, lang_ids=None):
"""Decoder forward
:param torch.Tensor hs_pad: batch of padded hidden state sequences (B, Tmax, D)
[in multi-encoder case,
list of torch.Tensor, [(B, Tmax_1, D), (B, Tmax_2, D), ..., ] ]
:param torch.Tensor hlens: batch of lengths of hidden state sequences (B)
[in multi-encoder case, list of torch.Tensor, [(B), (B), ..., ]
:param torch.Tensor ys_pad: batch of padded character id sequence tensor (B, Lmax)
:param int strm_idx: stream index indicates the index of decoding stream.
:param torch.Tensor lang_ids: batch of target language id tensor (B, 1)
:return: attention loss value
:rtype: torch.Tensor
:return: accuracy
:rtype: float
"""
# to support multiple encoder asr mode, in single encoder mode, convert torch.Tensor to List of torch.Tensor
if self.num_encs == 1:
hs_pad = [hs_pad]
hlens = [hlens]
# TODO(kan-bayashi): need to make more smart way
ys = [y[y != self.ignore_id] for y in ys_pad] # parse padded ys
# attention index for the attention module
# in SPA (speaker parallel attention), att_idx is used to select attention module. In other cases, it is 0.
att_idx = min(strm_idx, len(self.att) - 1)
# hlens should be list of list of integer
hlens = [list(map(int, hlens[idx])) for idx in range(self.num_encs)]
self.loss = None
# prepare input and output word sequences with sos/eos IDs
eos = ys[0].new([self.eos])
sos = ys[0].new([self.sos])
if self.replace_sos:
ys_in = [torch.cat([idx, y], dim=0) for idx, y in zip(lang_ids, ys)]
else:
ys_in = [torch.cat([sos, y], dim=0) for y in ys]
ys_out = [torch.cat([y, eos], dim=0) for y in ys]
# padding for ys with -1
# pys: utt x olen
ys_in_pad = pad_list(ys_in, self.eos)
ys_out_pad = pad_list(ys_out, self.ignore_id)
# get dim, length info
batch = ys_out_pad.size(0)
olength = ys_out_pad.size(1)
for idx in range(self.num_encs):
logging.info(
self.__class__.__name__ + 'Number of Encoder:{}; enc{}: input lengths: {}.'.format(self.num_encs,
idx + 1, hlens[idx]))
logging.info(self.__class__.__name__ + ' output lengths: ' + str([y.size(0) for y in ys_out]))
# initialization
c_list = [self.zero_state(hs_pad[0])]
z_list = [self.zero_state(hs_pad[0])]
for _ in six.moves.range(1, self.dlayers):
c_list.append(self.zero_state(hs_pad[0]))
z_list.append(self.zero_state(hs_pad[0]))
z_all = []
if self.num_encs == 1:
att_w = None
self.att[att_idx].reset() # reset pre-computation of h
else:
att_w_list = [None] * (self.num_encs + 1) # atts + han
att_c_list = [None] * (self.num_encs) # atts
for idx in range(self.num_encs + 1):
self.att[idx].reset() # reset pre-computation of h in atts and han
# pre-computation of embedding
eys = self.dropout_emb(self.embed(ys_in_pad)) # utt x olen x zdim
# loop for an output sequence
for i in six.moves.range(olength):
if self.num_encs == 1:
att_c, att_w = self.att[att_idx](hs_pad[0], hlens[0], self.dropout_dec[0](z_list[0]), att_w)
else:
for idx in range(self.num_encs):
att_c_list[idx], att_w_list[idx] = self.att[idx](hs_pad[idx], hlens[idx],
self.dropout_dec[0](z_list[0]), att_w_list[idx])
hs_pad_han = torch.stack(att_c_list, dim=1)
hlens_han = [self.num_encs] * len(ys_in)
att_c, att_w_list[self.num_encs] = self.att[self.num_encs](hs_pad_han, hlens_han,
self.dropout_dec[0](z_list[0]),
att_w_list[self.num_encs])
if i > 0 and random.random() < self.sampling_probability:
logging.info(' scheduled sampling ')
z_out = self.output(z_all[-1])
z_out = np.argmax(z_out.detach().cpu(), axis=1)
z_out = self.dropout_emb(self.embed(to_device(self, z_out)))
ey = torch.cat((z_out, att_c), dim=1) # utt x (zdim + hdim)
else:
ey = torch.cat((eys[:, i, :], att_c), dim=1) # utt x (zdim + hdim)
z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list)
if self.context_residual:
z_all.append(torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)) # utt x (zdim + hdim)
else:
z_all.append(self.dropout_dec[-1](z_list[-1])) # utt x (zdim)
z_all = torch.stack(z_all, dim=1).view(batch * olength, -1)
# compute loss
y_all = self.output(z_all)
if LooseVersion(torch.__version__) < LooseVersion('1.0'):
reduction_str = 'elementwise_mean'
else:
reduction_str = 'mean'
self.loss = F.cross_entropy(y_all, ys_out_pad.view(-1),
ignore_index=self.ignore_id,
reduction=reduction_str)
# compute perplexity
ppl = math.exp(self.loss.item())
# -1: eos, which is removed in the loss computation
self.loss *= (np.mean([len(x) for x in ys_in]) - 1)
acc = th_accuracy(y_all, ys_out_pad, ignore_label=self.ignore_id)
logging.info('att loss:' + ''.join(str(self.loss.item()).split('\n')))
# show predicted character sequence for debug
if self.verbose > 0 and self.char_list is not None:
ys_hat = y_all.view(batch, olength, -1)
ys_true = ys_out_pad
for (i, y_hat), y_true in zip(enumerate(ys_hat.detach().cpu().numpy()),
ys_true.detach().cpu().numpy()):
if i == MAX_DECODER_OUTPUT:
break
idx_hat = np.argmax(y_hat[y_true != self.ignore_id], axis=1)
idx_true = y_true[y_true != self.ignore_id]
seq_hat = [self.char_list[int(idx)] for idx in idx_hat]
seq_true = [self.char_list[int(idx)] for idx in idx_true]
seq_hat = "".join(seq_hat)
seq_true = "".join(seq_true)
logging.info("groundtruth[%d]: " % i + seq_true)
logging.info("prediction [%d]: " % i + seq_hat)
if self.labeldist is not None:
if self.vlabeldist is None:
self.vlabeldist = to_device(self, torch.from_numpy(self.labeldist))
loss_reg = - torch.sum((F.log_softmax(y_all, dim=1) * self.vlabeldist).view(-1), dim=0) / len(ys_in)
self.loss = (1. - self.lsm_weight) * self.loss + self.lsm_weight * loss_reg
return self.loss, acc, ppl
def recognize_beam(self, h, lpz, recog_args, char_list, rnnlm=None, strm_idx=0):
"""beam search implementation
:param torch.Tensor h: encoder hidden state (T, eprojs)
[in multi-encoder case, list of torch.Tensor, [(T1, eprojs), (T2, eprojs), ...] ]
:param torch.Tensor lpz: ctc log softmax output (T, odim)
[in multi-encoder case, list of torch.Tensor, [(T1, odim), (T2, odim), ...] ]
:param Namespace recog_args: argument Namespace containing options
:param char_list: list of character strings
:param torch.nn.Module rnnlm: language module
:param int strm_idx: stream index for speaker parallel attention in multi-speaker case
:return: N-best decoding results
:rtype: list of dicts
"""
# to support multiple encoder asr mode, in single encoder mode, convert torch.Tensor to List of torch.Tensor
if self.num_encs == 1:
h = [h]
lpz = [lpz]
if self.num_encs > 1 and lpz is None:
lpz = [lpz] * self.num_encs
for idx in range(self.num_encs):
logging.info('Number of Encoder:{}; enc{}: input lengths: {}.'.format(self.num_encs, idx + 1, h[0].size(0)))
att_idx = min(strm_idx, len(self.att) - 1)
# initialization
c_list = [self.zero_state(h[0].unsqueeze(0))]
z_list = [self.zero_state(h[0].unsqueeze(0))]
for _ in six.moves.range(1, self.dlayers):
c_list.append(self.zero_state(h[0].unsqueeze(0)))
z_list.append(self.zero_state(h[0].unsqueeze(0)))
if self.num_encs == 1:
a = None
self.att[att_idx].reset() # reset pre-computation of h
else:
a = [None] * (self.num_encs + 1) # atts + han
att_w_list = [None] * (self.num_encs + 1) # atts + han
att_c_list = [None] * (self.num_encs) # atts
for idx in range(self.num_encs + 1):
self.att[idx].reset() # reset pre-computation of h in atts and han
# search params
beam = recog_args.beam_size
penalty = recog_args.penalty
ctc_weight = getattr(recog_args, "ctc_weight", False) # for NMT
if lpz[0] is not None and self.num_encs > 1:
# weights-ctc, e.g. ctc_loss = w_1*ctc_1_loss + w_2 * ctc_2_loss + w_N * ctc_N_loss
weights_ctc_dec = recog_args.weights_ctc_dec / np.sum(recog_args.weights_ctc_dec) # normalize
logging.info('ctc weights (decoding): ' + ' '.join([str(x) for x in weights_ctc_dec]))
else:
weights_ctc_dec = [1.0]
# prepare sos
if self.replace_sos and recog_args.tgt_lang:
y = char_list.index(recog_args.tgt_lang)
else:
y = self.sos
logging.info('<sos> index: ' + str(y))
logging.info('<sos> mark: ' + char_list[y])
vy = h[0].new_zeros(1).long()
maxlen = np.amin([h[idx].size(0) for idx in range(self.num_encs)])
if recog_args.maxlenratio != 0:
# maxlen >= 1
maxlen = max(1, int(recog_args.maxlenratio * maxlen))
minlen = int(recog_args.minlenratio * maxlen)
logging.info('max output length: ' + str(maxlen))
logging.info('min output length: ' + str(minlen))
# initialize hypothesis
if rnnlm:
hyp = {'score': 0.0, 'yseq': [y], 'c_prev': c_list,
'z_prev': z_list, 'a_prev': a, 'rnnlm_prev': None}
else:
hyp = {'score': 0.0, 'yseq': [y], 'c_prev': c_list, 'z_prev': z_list, 'a_prev': a}
if lpz[0] is not None:
ctc_prefix_score = [CTCPrefixScore(lpz[idx].detach().numpy(), 0, self.eos, np) for idx in
range(self.num_encs)]
hyp['ctc_state_prev'] = [ctc_prefix_score[idx].initial_state() for idx in range(self.num_encs)]
hyp['ctc_score_prev'] = [0.0] * self.num_encs
if ctc_weight != 1.0:
# pre-pruning based on attention scores
ctc_beam = min(lpz[0].shape[-1], int(beam * CTC_SCORING_RATIO))
else:
ctc_beam = lpz[0].shape[-1]
hyps = [hyp]
ended_hyps = []
for i in six.moves.range(maxlen):
logging.debug('position ' + str(i))
hyps_best_kept = []
for hyp in hyps:
vy.unsqueeze(1)
vy[0] = hyp['yseq'][i]
ey = self.dropout_emb(self.embed(vy)) # utt list (1) x zdim
ey.unsqueeze(0)
if self.num_encs == 1:
att_c, att_w = self.att[att_idx](h[0].unsqueeze(0), [h[0].size(0)],
self.dropout_dec[0](hyp['z_prev'][0]), hyp['a_prev'])
else:
for idx in range(self.num_encs):
att_c_list[idx], att_w_list[idx] = self.att[idx](h[idx].unsqueeze(0), [h[idx].size(0)],
self.dropout_dec[0](hyp['z_prev'][0]),
hyp['a_prev'][idx])
h_han = torch.stack(att_c_list, dim=1)
att_c, att_w_list[self.num_encs] = self.att[self.num_encs](h_han, [self.num_encs],
self.dropout_dec[0](hyp['z_prev'][0]),
hyp['a_prev'][self.num_encs])
ey = torch.cat((ey, att_c), dim=1) # utt(1) x (zdim + hdim)
z_list, c_list = self.rnn_forward(ey, z_list, c_list, hyp['z_prev'], hyp['c_prev'])
# get nbest local scores and their ids
if self.context_residual:
logits = self.output(torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1))
else:
logits = self.output(self.dropout_dec[-1](z_list[-1]))
local_att_scores = F.log_softmax(logits, dim=1)
if rnnlm:
rnnlm_state, local_lm_scores = rnnlm.predict(hyp['rnnlm_prev'], vy)
local_scores = local_att_scores + recog_args.lm_weight * local_lm_scores
else:
local_scores = local_att_scores
if lpz[0] is not None:
local_best_scores, local_best_ids = torch.topk(
local_att_scores, ctc_beam, dim=1)
ctc_scores, ctc_states = [None] * self.num_encs, [None] * self.num_encs
for idx in range(self.num_encs):
ctc_scores[idx], ctc_states[idx] = ctc_prefix_score[idx](
hyp['yseq'], local_best_ids[0], hyp['ctc_state_prev'][idx])
local_scores = \
(1.0 - ctc_weight) * local_att_scores[:, local_best_ids[0]]
if self.num_encs == 1:
local_scores += ctc_weight * torch.from_numpy(ctc_scores[0] - hyp['ctc_score_prev'][0])
else:
for idx in range(self.num_encs):
local_scores += ctc_weight * weights_ctc_dec[idx] * torch.from_numpy(
ctc_scores[idx] - hyp['ctc_score_prev'][idx])
if rnnlm:
local_scores += recog_args.lm_weight * local_lm_scores[:, local_best_ids[0]]
local_best_scores, joint_best_ids = torch.topk(local_scores, beam, dim=1)
local_best_ids = local_best_ids[:, joint_best_ids[0]]
else:
local_best_scores, local_best_ids = torch.topk(local_scores, beam, dim=1)
for j in six.moves.range(beam):
new_hyp = {}
# [:] is needed!
new_hyp['z_prev'] = z_list[:]
new_hyp['c_prev'] = c_list[:]
if self.num_encs == 1:
new_hyp['a_prev'] = att_w[:]
else:
new_hyp['a_prev'] = [att_w_list[idx][:] for idx in range(self.num_encs + 1)]
new_hyp['score'] = hyp['score'] + local_best_scores[0, j]
new_hyp['yseq'] = [0] * (1 + len(hyp['yseq']))
new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq']
new_hyp['yseq'][len(hyp['yseq'])] = int(local_best_ids[0, j])
if rnnlm:
new_hyp['rnnlm_prev'] = rnnlm_state
if lpz[0] is not None:
new_hyp['ctc_state_prev'] = [ctc_states[idx][joint_best_ids[0, j]] for idx in
range(self.num_encs)]
new_hyp['ctc_score_prev'] = [ctc_scores[idx][joint_best_ids[0, j]] for idx in
range(self.num_encs)]
# will be (2 x beam) hyps at most
hyps_best_kept.append(new_hyp)
hyps_best_kept = sorted(
hyps_best_kept, key=lambda x: x['score'], reverse=True)[:beam]
# sort and get nbest
hyps = hyps_best_kept
logging.debug('number of pruned hypotheses: ' + str(len(hyps)))
logging.debug(
'best hypo: ' + ''.join([char_list[int(x)] for x in hyps[0]['yseq'][1:]]))
# add eos in the final loop so that at least one hypothesis ends
if i == maxlen - 1:
logging.info('adding <eos> in the last position in the loop')
for hyp in hyps:
hyp['yseq'].append(self.eos)
# add ended hypotheses to the final list and remove them from the current hypotheses
# (when this happens, the number of hyps drops below beam)
remained_hyps = []
for hyp in hyps:
if hyp['yseq'][-1] == self.eos:
# only store the sequence that has more than minlen outputs
# also add penalty
if len(hyp['yseq']) > minlen:
hyp['score'] += (i + 1) * penalty
if rnnlm: # Word LM needs to add final <eos> score
hyp['score'] += recog_args.lm_weight * rnnlm.final(
hyp['rnnlm_prev'])
ended_hyps.append(hyp)
else:
remained_hyps.append(hyp)
# end detection
if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
logging.info('end detected at %d', i)
break
hyps = remained_hyps
if len(hyps) > 0:
logging.debug('remaining hypotheses: ' + str(len(hyps)))
else:
logging.info('no hypothesis. Finish decoding.')
break
for hyp in hyps:
logging.debug(
'hypo: ' + ''.join([char_list[int(x)] for x in hyp['yseq'][1:]]))
logging.debug('number of ended hypotheses: ' + str(len(ended_hyps)))
nbest_hyps = sorted(
ended_hyps, key=lambda x: x['score'], reverse=True)[:min(len(ended_hyps), recog_args.nbest)]
# check number of hypotheses
if len(nbest_hyps) == 0:
logging.warning('there are no N-best results, perform recognition again with smaller minlenratio.')
# should copy because Namespace will be overwritten globally
recog_args = Namespace(**vars(recog_args))
recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1)
if self.num_encs == 1:
return self.recognize_beam(h[0], lpz[0], recog_args, char_list, rnnlm)
else:
return self.recognize_beam(h, lpz, recog_args, char_list, rnnlm)
logging.info('total log probability: ' + str(nbest_hyps[0]['score']))
logging.info('normalized log probability: ' + str(nbest_hyps[0]['score'] / len(nbest_hyps[0]['yseq'])))
# remove sos
return nbest_hyps
def recognize_beam_batch(self, h, hlens, lpz, recog_args, char_list, rnnlm=None,
normalize_score=True, strm_idx=0, lang_ids=None):
# to support multiple encoder asr mode, in single encoder mode, convert torch.Tensor to List of torch.Tensor
if self.num_encs == 1:
h = [h]
hlens = [hlens]
lpz = [lpz]
if self.num_encs > 1 and lpz is None:
lpz = [lpz] * self.num_encs
att_idx = min(strm_idx, len(self.att) - 1)
for idx in range(self.num_encs):
logging.info(
'Number of Encoder:{}; enc{}: input lengths: {}.'.format(self.num_encs, idx + 1, h[idx].size(1)))
h[idx] = mask_by_length(h[idx], hlens[idx], 0.0)
# search params
batch = len(hlens[0])
beam = recog_args.beam_size
penalty = recog_args.penalty
ctc_weight = getattr(recog_args, "ctc_weight", 0) # for NMT
att_weight = 1.0 - ctc_weight
ctc_margin = getattr(recog_args, "ctc_window_margin", 0) # use getattr to keep compatibility
# weights-ctc, e.g. ctc_loss = w_1*ctc_1_loss + w_2 * ctc_2_loss + w_N * ctc_N_loss
if lpz[0] is not None and self.num_encs > 1:
weights_ctc_dec = recog_args.weights_ctc_dec / np.sum(recog_args.weights_ctc_dec) # normalize
logging.info('ctc weights (decoding): ' + ' '.join([str(x) for x in weights_ctc_dec]))
else:
weights_ctc_dec = [1.0]
n_bb = batch * beam
pad_b = to_device(self, torch.arange(batch) * beam).view(-1, 1)
max_hlen = np.amin([max(hlens[idx]) for idx in range(self.num_encs)])
if recog_args.maxlenratio == 0:
maxlen = max_hlen
else:
maxlen = max(1, int(recog_args.maxlenratio * max_hlen))
minlen = int(recog_args.minlenratio * max_hlen)
logging.info('max output length: ' + str(maxlen))
logging.info('min output length: ' + str(minlen))
# initialization
c_prev = [to_device(self, torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)]
z_prev = [to_device(self, torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)]
c_list = [to_device(self, torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)]
z_list = [to_device(self, torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)]
vscores = to_device(self, torch.zeros(batch, beam))
rnnlm_state = None
if self.num_encs == 1:
a_prev = [None]
att_w_list, ctc_scorer, ctc_state = [None], [None], [None]
self.att[att_idx].reset() # reset pre-computation of h
else:
a_prev = [None] * (self.num_encs + 1) # atts + han
att_w_list = [None] * (self.num_encs + 1) # atts + han
att_c_list = [None] * (self.num_encs) # atts
ctc_scorer, ctc_state = [None] * (self.num_encs), [None] * (self.num_encs)
for idx in range(self.num_encs + 1):
self.att[idx].reset() # reset pre-computation of h in atts and han
if self.replace_sos and recog_args.tgt_lang:
logging.info('<sos> index: ' + str(char_list.index(recog_args.tgt_lang)))
logging.info('<sos> mark: ' + recog_args.tgt_lang)
yseq = [[char_list.index(recog_args.tgt_lang)] for _ in six.moves.range(n_bb)]
elif lang_ids is not None:
# NOTE: used for evaluation during training
yseq = [[lang_ids[b // recog_args.beam_size]] for b in six.moves.range(n_bb)]
else:
logging.info('<sos> index: ' + str(self.sos))
logging.info('<sos> mark: ' + char_list[self.sos])
yseq = [[self.sos] for _ in six.moves.range(n_bb)]
accum_odim_ids = [self.sos for _ in six.moves.range(n_bb)]
stop_search = [False for _ in six.moves.range(batch)]
nbest_hyps = [[] for _ in six.moves.range(batch)]
ended_hyps = [[] for _ in range(batch)]
exp_hlens = [hlens[idx].repeat(beam).view(beam, batch).transpose(0, 1).contiguous() for idx in
range(self.num_encs)]
exp_hlens = [exp_hlens[idx].view(-1).tolist() for idx in range(self.num_encs)]
exp_h = [h[idx].unsqueeze(1).repeat(1, beam, 1, 1).contiguous() for idx in range(self.num_encs)]
exp_h = [exp_h[idx].view(n_bb, h[idx].size()[1], h[idx].size()[2]) for idx in range(self.num_encs)]
if lpz[0] is not None:
scoring_ratio = CTC_SCORING_RATIO if att_weight > 0.0 and not lpz[0].is_cuda else 0
ctc_scorer = [CTCPrefixScoreTH(lpz[idx], hlens[idx], 0, self.eos, beam,
scoring_ratio, margin=ctc_margin) for idx in range(self.num_encs)]
for i in six.moves.range(maxlen):
logging.debug('position ' + str(i))
vy = to_device(self, torch.LongTensor(self._get_last_yseq(yseq)))
ey = self.dropout_emb(self.embed(vy))
if self.num_encs == 1:
att_c, att_w = self.att[att_idx](exp_h[0], exp_hlens[0], self.dropout_dec[0](z_prev[0]), a_prev[0])
att_w_list = [att_w]
else:
for idx in range(self.num_encs):
att_c_list[idx], att_w_list[idx] = self.att[idx](exp_h[idx], exp_hlens[idx],
self.dropout_dec[0](z_prev[0]), a_prev[idx])
exp_h_han = torch.stack(att_c_list, dim=1)
att_c, att_w_list[self.num_encs] = self.att[self.num_encs](exp_h_han, [self.num_encs] * n_bb,
self.dropout_dec[0](z_prev[0]),
a_prev[self.num_encs])
ey = torch.cat((ey, att_c), dim=1)
# attention decoder
z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_prev, c_prev)
if self.context_residual:
logits = self.output(torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1))
else:
logits = self.output(self.dropout_dec[-1](z_list[-1]))
local_scores = att_weight * F.log_softmax(logits, dim=1)
# rnnlm
if rnnlm:
rnnlm_state, local_lm_scores = rnnlm.buff_predict(rnnlm_state, vy, n_bb)
local_scores = local_scores + recog_args.lm_weight * local_lm_scores
# ctc
if ctc_scorer[0]:
for idx in range(self.num_encs):
att_w = att_w_list[idx]
att_w_ = att_w if isinstance(att_w, torch.Tensor) else att_w[0]
ctc_state[idx], local_ctc_scores = ctc_scorer[idx](yseq, ctc_state[idx], local_scores, att_w_)
local_scores = local_scores + ctc_weight * weights_ctc_dec[idx] * local_ctc_scores
local_scores = local_scores.view(batch, beam, self.odim)
if i == 0:
local_scores[:, 1:, :] = self.logzero
# accumulate scores
eos_vscores = local_scores[:, :, self.eos] + vscores
vscores = vscores.view(batch, beam, 1).repeat(1, 1, self.odim)
vscores[:, :, self.eos] = self.logzero
vscores = (vscores + local_scores).view(batch, -1)
# global pruning
accum_best_scores, accum_best_ids = torch.topk(vscores, beam, 1)
accum_odim_ids = torch.fmod(accum_best_ids, self.odim).view(-1).data.cpu().tolist()
accum_padded_beam_ids = (torch.div(accum_best_ids, self.odim) + pad_b).view(-1).data.cpu().tolist()
y_prev = yseq[:][:]
yseq = self._index_select_list(yseq, accum_padded_beam_ids)
yseq = self._append_ids(yseq, accum_odim_ids)
vscores = accum_best_scores
vidx = to_device(self, torch.LongTensor(accum_padded_beam_ids))
a_prev = []
num_atts = self.num_encs if self.num_encs == 1 else self.num_encs + 1
for idx in range(num_atts):
if isinstance(att_w_list[idx], torch.Tensor):
_a_prev = torch.index_select(att_w_list[idx].view(n_bb, *att_w_list[idx].shape[1:]), 0, vidx)
elif isinstance(att_w_list[idx], list):
# handle the case of multi-head attention
_a_prev = [torch.index_select(att_w_one.view(n_bb, -1), 0, vidx) for att_w_one in att_w_list[idx]]
else:
# handle the case of location_recurrent when return is a tuple
_a_prev_ = torch.index_select(att_w_list[idx][0].view(n_bb, -1), 0, vidx)
_h_prev_ = torch.index_select(att_w_list[idx][1][0].view(n_bb, -1), 0, vidx)
_c_prev_ = torch.index_select(att_w_list[idx][1][1].view(n_bb, -1), 0, vidx)
_a_prev = (_a_prev_, (_h_prev_, _c_prev_))
a_prev.append(_a_prev)
z_prev = [torch.index_select(z_list[li].view(n_bb, -1), 0, vidx) for li in range(self.dlayers)]
c_prev = [torch.index_select(c_list[li].view(n_bb, -1), 0, vidx) for li in range(self.dlayers)]
# pick ended hyps
if i > minlen:
k = 0
penalty_i = (i + 1) * penalty
thr = accum_best_scores[:, -1]
for samp_i in six.moves.range(batch):
if stop_search[samp_i]:
k = k + beam
continue
for beam_j in six.moves.range(beam):
if eos_vscores[samp_i, beam_j] > thr[samp_i]:
yk = y_prev[k][:]
yk.append(self.eos)
if len(yk) < min(hlens[idx][samp_i] for idx in range(self.num_encs)):
_vscore = eos_vscores[samp_i][beam_j] + penalty_i
if rnnlm:
_vscore += recog_args.lm_weight * rnnlm.final(rnnlm_state, index=k)
_score = _vscore.data.cpu().numpy()
ended_hyps[samp_i].append({'yseq': yk, 'vscore': _vscore, 'score': _score})
k = k + 1
# end detection
stop_search = [stop_search[samp_i] or end_detect(ended_hyps[samp_i], i)
for samp_i in six.moves.range(batch)]
stop_search_summary = list(set(stop_search))
if len(stop_search_summary) == 1 and stop_search_summary[0]:
break
if rnnlm:
rnnlm_state = self._index_select_lm_state(rnnlm_state, 0, vidx)
if ctc_scorer[0]:
for idx in range(self.num_encs):
ctc_state[idx] = ctc_scorer[idx].index_select_state(ctc_state[idx], accum_best_ids)
torch.cuda.empty_cache()
dummy_hyps = [{'yseq': [self.sos, self.eos], 'score': np.array([-float('inf')])}]
ended_hyps = [ended_hyps[samp_i] if len(ended_hyps[samp_i]) != 0 else dummy_hyps
for samp_i in six.moves.range(batch)]
if normalize_score:
for samp_i in six.moves.range(batch):
for x in ended_hyps[samp_i]:
x['score'] /= len(x['yseq'])
nbest_hyps = [sorted(ended_hyps[samp_i], key=lambda x: x['score'],
reverse=True)[:min(len(ended_hyps[samp_i]), recog_args.nbest)]
for samp_i in six.moves.range(batch)]
return nbest_hyps
def calculate_all_attentions(self, hs_pad, hlen, ys_pad, strm_idx=0, lang_ids=None):
"""Calculate all of attentions
:param torch.Tensor hs_pad: batch of padded hidden state sequences (B, Tmax, D)
[in multi-encoder case,
list of torch.Tensor, [(B, Tmax_1, D), (B, Tmax_2, D), ..., ] ]
:param torch.Tensor hlen: batch of lengths of hidden state sequences (B)
[in multi-encoder case, list of torch.Tensor, [(B), (B), ..., ]
:param torch.Tensor ys_pad: batch of padded character id sequence tensor (B, Lmax)
:param int strm_idx: stream index for parallel speaker attention in multi-speaker case
:param torch.Tensor lang_ids: batch of target language id tensor (B, 1)
:return: attention weights with the following shape,
1) multi-head case => attention weights (B, H, Lmax, Tmax),
2) multi-encoder case => [(B, Lmax, Tmax1), (B, Lmax, Tmax2), ..., (B, Lmax, NumEncs)]
3) other case => attention weights (B, Lmax, Tmax).
:rtype: float ndarray
"""
# to support multiple encoder asr mode, in single encoder mode, convert torch.Tensor to List of torch.Tensor
if self.num_encs == 1:
hs_pad = [hs_pad]
hlen = [hlen]
# TODO(kan-bayashi): need to make more smart way
ys = [y[y != self.ignore_id] for y in ys_pad] # parse padded ys
att_idx = min(strm_idx, len(self.att) - 1)
# hlen should be list of list of integer
hlen = [list(map(int, hlen[idx])) for idx in range(self.num_encs)]
self.loss = None
# prepare input and output word sequences with sos/eos IDs
eos = ys[0].new([self.eos])
sos = ys[0].new([self.sos])
if self.replace_sos:
ys_in = [torch.cat([idx, y], dim=0) for idx, y in zip(lang_ids, ys)]
else:
ys_in = [torch.cat([sos, y], dim=0) for y in ys]
ys_out = [torch.cat([y, eos], dim=0) for y in ys]
# padding for ys with -1
# pys: utt x olen
ys_in_pad = pad_list(ys_in, self.eos)
ys_out_pad = pad_list(ys_out, self.ignore_id)
# get length info
olength = ys_out_pad.size(1)
# initialization
c_list = [self.zero_state(hs_pad[0])]
z_list = [self.zero_state(hs_pad[0])]
for _ in six.moves.range(1, self.dlayers):
c_list.append(self.zero_state(hs_pad[0]))
z_list.append(self.zero_state(hs_pad[0]))
att_ws = []
if self.num_encs == 1:
att_w = None
self.att[att_idx].reset() # reset pre-computation of h
else:
att_w_list = [None] * (self.num_encs + 1) # atts + han
att_c_list = [None] * (self.num_encs) # atts
for idx in range(self.num_encs + 1):
self.att[idx].reset() # reset pre-computation of h in atts and han
# pre-computation of embedding
eys = self.dropout_emb(self.embed(ys_in_pad)) # utt x olen x zdim
# loop for an output sequence
for i in six.moves.range(olength):
if self.num_encs == 1:
att_c, att_w = self.att[att_idx](hs_pad[0], hlen[0], self.dropout_dec[0](z_list[0]), att_w)
att_ws.append(att_w)
else:
for idx in range(self.num_encs):
att_c_list[idx], att_w_list[idx] = self.att[idx](hs_pad[idx], hlen[idx],
self.dropout_dec[0](z_list[0]), att_w_list[idx])
hs_pad_han = torch.stack(att_c_list, dim=1)
hlen_han = [self.num_encs] * len(ys_in)
att_c, att_w_list[self.num_encs] = self.att[self.num_encs](hs_pad_han, hlen_han,
self.dropout_dec[0](z_list[0]),
att_w_list[self.num_encs])
att_ws.append(att_w_list)
ey = torch.cat((eys[:, i, :], att_c), dim=1) # utt x (zdim + hdim)
z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list)
if self.num_encs == 1:
# convert to numpy array with the shape (B, Lmax, Tmax)
att_ws = att_to_numpy(att_ws, self.att[att_idx])
else:
_att_ws = []
for idx, ws in enumerate(zip(*att_ws)):
ws = att_to_numpy(ws, self.att[idx])
_att_ws.append(ws)
att_ws = _att_ws
return att_ws
@staticmethod
def _get_last_yseq(exp_yseq):
last = []
for y_seq in exp_yseq:
last.append(y_seq[-1])
return last
@staticmethod
def _append_ids(yseq, ids):
if isinstance(ids, list):
for i, j in enumerate(ids):
yseq[i].append(j)
else:
for i in range(len(yseq)):
yseq[i].append(ids)
return yseq
@staticmethod
def _index_select_list(yseq, lst):
new_yseq = []
for l in lst:
new_yseq.append(yseq[l][:])
return new_yseq
@staticmethod
def _index_select_lm_state(rnnlm_state, dim, vidx):
if isinstance(rnnlm_state, dict):
new_state = {}
for k, v in rnnlm_state.items():
new_state[k] = [torch.index_select(vi, dim, vidx) for vi in v]
elif isinstance(rnnlm_state, list):
new_state = []
for i in vidx:
new_state.append(rnnlm_state[int(i)][:])
return new_state
# scorer interface methods
def init_state(self, x):
# to support mutiple encoder asr mode, in single encoder mode, convert torch.Tensor to List of torch.Tensor
if self.num_encs == 1:
x = [x]
c_list = [self.zero_state(x[0].unsqueeze(0))]
z_list = [self.zero_state(x[0].unsqueeze(0))]
for _ in six.moves.range(1, self.dlayers):
c_list.append(self.zero_state(x[0].unsqueeze(0)))
z_list.append(self.zero_state(x[0].unsqueeze(0)))
# TODO(karita): support strm_index for `asr_mix`
strm_index = 0
att_idx = min(strm_index, len(self.att) - 1)
if self.num_encs == 1:
a = None
self.att[att_idx].reset() # reset pre-computation of h
else:
a = [None] * (self.num_encs + 1) # atts + han
for idx in range(self.num_encs + 1):
self.att[idx].reset() # reset pre-computation of h in atts and han
return dict(c_prev=c_list[:], z_prev=z_list[:], a_prev=a, workspace=(att_idx, z_list, c_list))
def score(self, yseq, state, x):
# to support mutiple encoder asr mode, in single encoder mode, convert torch.Tensor to List of torch.Tensor
if self.num_encs == 1:
x = [x]
att_idx, z_list, c_list = state["workspace"]
vy = yseq[-1].unsqueeze(0)
ey = self.dropout_emb(self.embed(vy)) # utt list (1) x zdim
if self.num_encs == 1:
att_c, att_w = self.att[att_idx](
x[0].unsqueeze(0), [x[0].size(0)],
self.dropout_dec[0](state['z_prev'][0]), state['a_prev'])
else:
att_w = [None] * (self.num_encs + 1) # atts + han
att_c_list = [None] * (self.num_encs) # atts
for idx in range(self.num_encs):
att_c_list[idx], att_w[idx] = self.att[idx](x[idx].unsqueeze(0), [x[idx].size(0)],
self.dropout_dec[0](state['z_prev'][0]),
state['a_prev'][idx])
h_han = torch.stack(att_c_list, dim=1)
att_c, att_w[self.num_encs] = self.att[self.num_encs](h_han, [self.num_encs],
self.dropout_dec[0](state['z_prev'][0]),
state['a_prev'][self.num_encs])
ey = torch.cat((ey, att_c), dim=1) # utt(1) x (zdim + hdim)
z_list, c_list = self.rnn_forward(ey, z_list, c_list, state['z_prev'], state['c_prev'])
if self.context_residual:
logits = self.output(torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1))
else:
logits = self.output(self.dropout_dec[-1](z_list[-1]))
logp = F.log_softmax(logits, dim=1).squeeze(0)
return logp, dict(c_prev=c_list[:], z_prev=z_list[:], a_prev=att_w, workspace=(att_idx, z_list, c_list))
__init__(self, eprojs, odim, dtype, dlayers, dunits, sos, eos, att, verbose=0, char_list=None, labeldist=None, lsm_weight=0.0, sampling_probability=0.0, dropout=0.0, context_residual=False, replace_sos=False, num_encs=1)
special
¶
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/decoders.py
def __init__(self, eprojs, odim, dtype, dlayers, dunits, sos, eos, att, verbose=0,
char_list=None, labeldist=None, lsm_weight=0., sampling_probability=0.0,
dropout=0.0, context_residual=False, replace_sos=False, num_encs=1):
torch.nn.Module.__init__(self)
self.dtype = dtype
self.dunits = dunits
self.dlayers = dlayers
self.context_residual = context_residual
self.embed = torch.nn.Embedding(odim, dunits)
self.dropout_emb = torch.nn.Dropout(p=dropout)
self.decoder = torch.nn.ModuleList()
self.dropout_dec = torch.nn.ModuleList()
self.decoder += [
torch.nn.LSTMCell(dunits + eprojs, dunits) if self.dtype == "lstm" else torch.nn.GRUCell(dunits + eprojs,
dunits)]
self.dropout_dec += [torch.nn.Dropout(p=dropout)]
for _ in six.moves.range(1, self.dlayers):
self.decoder += [
torch.nn.LSTMCell(dunits, dunits) if self.dtype == "lstm" else torch.nn.GRUCell(dunits, dunits)]
self.dropout_dec += [torch.nn.Dropout(p=dropout)]
# NOTE: dropout is applied only for the vertical connections
# see https://arxiv.org/pdf/1409.2329.pdf
self.ignore_id = -1
if context_residual:
self.output = torch.nn.Linear(dunits + eprojs, odim)
else:
self.output = torch.nn.Linear(dunits, odim)
self.loss = None
self.att = att
self.dunits = dunits
self.sos = sos
self.eos = eos
self.odim = odim
self.verbose = verbose
self.char_list = char_list
# for label smoothing
self.labeldist = labeldist
self.vlabeldist = None
self.lsm_weight = lsm_weight
self.sampling_probability = sampling_probability
self.dropout = dropout
self.num_encs = num_encs
# for multilingual E2E-ST
self.replace_sos = replace_sos
self.logzero = -10000000000.0
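Example (not part of the generated source): a minimal sketch of constructing a Decoder with a stand-in attention module. The import path and the UniformAtt helper are assumptions for illustration; in the toolkit the attention modules come from the accompanying attention implementations.
import torch
# assumed import path, based on the file location shown above
from tools.espnet_minimal.nets.pytorch_backend.rnn.decoders import Decoder

class UniformAtt(torch.nn.Module):
    """Hypothetical stand-in attention: uniform average over encoder frames."""
    def reset(self):
        pass  # nothing to pre-compute
    def forward(self, hs_pad, hlens, dec_z, att_prev):
        # att_w: (B, Tmax) uniform weights, att_c: (B, eprojs) context vector
        att_w = hs_pad.new_ones(hs_pad.size(0), hs_pad.size(1)) / hs_pad.size(1)
        att_c = (hs_pad * att_w.unsqueeze(-1)).sum(dim=1)
        return att_c, att_w

eprojs, odim, dunits = 320, 52, 300
dec = Decoder(eprojs, odim, "lstm", dlayers=1, dunits=dunits,
              sos=odim - 1, eos=odim - 1,
              att=torch.nn.ModuleList([UniformAtt()]))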
calculate_all_attentions(self, hs_pad, hlen, ys_pad, strm_idx=0, lang_ids=None)
¶
Calculate all of attentions
:param torch.Tensor hs_pad: batch of padded hidden state sequences (B, Tmax, D) [in multi-encoder case, list of torch.Tensor, [(B, Tmax_1, D), (B, Tmax_2, D), ..., ] ]
:param torch.Tensor hlen: batch of lengths of hidden state sequences (B) [in multi-encoder case, list of torch.Tensor, [(B), (B), ..., ]
:param torch.Tensor ys_pad: batch of padded character id sequence tensor (B, Lmax)
:param int strm_idx: stream index for parallel speaker attention in multi-speaker case
:param torch.Tensor lang_ids: batch of target language id tensor (B, 1)
:return: attention weights with the following shape, 1) multi-head case => attention weights (B, H, Lmax, Tmax), 2) multi-encoder case => [(B, Lmax, Tmax1), (B, Lmax, Tmax2), ..., (B, Lmax, NumEncs)], 3) other case => attention weights (B, Lmax, Tmax).
:rtype: float ndarray
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/decoders.py
def calculate_all_attentions(self, hs_pad, hlen, ys_pad, strm_idx=0, lang_ids=None):
"""Calculate all of attentions
:param torch.Tensor hs_pad: batch of padded hidden state sequences (B, Tmax, D)
[in multi-encoder case,
list of torch.Tensor, [(B, Tmax_1, D), (B, Tmax_2, D), ..., ] ]
:param torch.Tensor hlen: batch of lengths of hidden state sequences (B)
[in multi-encoder case, list of torch.Tensor, [(B), (B), ..., ]
:param torch.Tensor ys_pad: batch of padded character id sequence tensor (B, Lmax)
:param int strm_idx: stream index for parallel speaker attention in multi-speaker case
:param torch.Tensor lang_ids: batch of target language id tensor (B, 1)
:return: attention weights with the following shape,
1) multi-head case => attention weights (B, H, Lmax, Tmax),
2) multi-encoder case => [(B, Lmax, Tmax1), (B, Lmax, Tmax2), ..., (B, Lmax, NumEncs)]
3) other case => attention weights (B, Lmax, Tmax).
:rtype: float ndarray
"""
# to support mutiple encoder asr mode, in single encoder mode, convert torch.Tensor to List of torch.Tensor
if self.num_encs == 1:
hs_pad = [hs_pad]
hlen = [hlen]
# TODO(kan-bayashi): need to make more smart way
ys = [y[y != self.ignore_id] for y in ys_pad] # parse padded ys
att_idx = min(strm_idx, len(self.att) - 1)
# hlen should be list of list of integer
hlen = [list(map(int, hlen[idx])) for idx in range(self.num_encs)]
self.loss = None
# prepare input and output word sequences with sos/eos IDs
eos = ys[0].new([self.eos])
sos = ys[0].new([self.sos])
if self.replace_sos:
ys_in = [torch.cat([idx, y], dim=0) for idx, y in zip(lang_ids, ys)]
else:
ys_in = [torch.cat([sos, y], dim=0) for y in ys]
ys_out = [torch.cat([y, eos], dim=0) for y in ys]
# padding for ys with -1
# pys: utt x olen
ys_in_pad = pad_list(ys_in, self.eos)
ys_out_pad = pad_list(ys_out, self.ignore_id)
# get length info
olength = ys_out_pad.size(1)
# initialization
c_list = [self.zero_state(hs_pad[0])]
z_list = [self.zero_state(hs_pad[0])]
for _ in six.moves.range(1, self.dlayers):
c_list.append(self.zero_state(hs_pad[0]))
z_list.append(self.zero_state(hs_pad[0]))
att_ws = []
if self.num_encs == 1:
att_w = None
self.att[att_idx].reset() # reset pre-computation of h
else:
att_w_list = [None] * (self.num_encs + 1) # atts + han
att_c_list = [None] * (self.num_encs) # atts
for idx in range(self.num_encs + 1):
self.att[idx].reset() # reset pre-computation of h in atts and han
# pre-computation of embedding
eys = self.dropout_emb(self.embed(ys_in_pad)) # utt x olen x zdim
# loop for an output sequence
for i in six.moves.range(olength):
if self.num_encs == 1:
att_c, att_w = self.att[att_idx](hs_pad[0], hlen[0], self.dropout_dec[0](z_list[0]), att_w)
att_ws.append(att_w)
else:
for idx in range(self.num_encs):
att_c_list[idx], att_w_list[idx] = self.att[idx](hs_pad[idx], hlen[idx],
self.dropout_dec[0](z_list[0]), att_w_list[idx])
hs_pad_han = torch.stack(att_c_list, dim=1)
hlen_han = [self.num_encs] * len(ys_in)
att_c, att_w_list[self.num_encs] = self.att[self.num_encs](hs_pad_han, hlen_han,
self.dropout_dec[0](z_list[0]),
att_w_list[self.num_encs])
att_ws.append(att_w_list)
ey = torch.cat((eys[:, i, :], att_c), dim=1) # utt x (zdim + hdim)
z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list)
if self.num_encs == 1:
# convert to numpy array with the shape (B, Lmax, Tmax)
att_ws = att_to_numpy(att_ws, self.att[att_idx])
else:
_att_ws = []
for idx, ws in enumerate(zip(*att_ws)):
ws = att_to_numpy(ws, self.att[idx])
_att_ws.append(ws)
att_ws = _att_ws
return att_ws
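A hedged usage sketch, continuing the toy decoder dec from the __init__ sketch above; all tensors are random placeholders, and the printed shape assumes the default single-head attention path.
import torch

B, Tmax, Lmax = 4, 50, 12
hs_pad = torch.randn(B, Tmax, eprojs)             # fake encoder output (B, Tmax, eprojs)
hlens = torch.full((B,), Tmax, dtype=torch.long)  # encoder output lengths
ys_pad = torch.randint(0, odim - 1, (B, Lmax))    # padded target ids

with torch.no_grad():
    att_ws = dec.calculate_all_attentions(hs_pad, hlens, ys_pad)
print(att_ws.shape)  # (B, Lmax + 1, Tmax) here: the appended <eos> adds one output step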
forward(self, hs_pad, hlens, ys_pad, strm_idx=0, lang_ids=None)
¶
Decoder forward
:param torch.Tensor hs_pad: batch of padded hidden state sequences (B, Tmax, D) [in multi-encoder case, list of torch.Tensor, [(B, Tmax_1, D), (B, Tmax_2, D), ..., ] ]
:param torch.Tensor hlens: batch of lengths of hidden state sequences (B) [in multi-encoder case, list of torch.Tensor, [(B), (B), ..., ]
:param torch.Tensor ys_pad: batch of padded character id sequence tensor (B, Lmax)
:param int strm_idx: stream index indicates the index of decoding stream.
:param torch.Tensor lang_ids: batch of target language id tensor (B, 1)
:return: attention loss value
:rtype: torch.Tensor
:return: accuracy
:rtype: float
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/decoders.py
def forward(self, hs_pad, hlens, ys_pad, strm_idx=0, lang_ids=None):
"""Decoder forward
:param torch.Tensor hs_pad: batch of padded hidden state sequences (B, Tmax, D)
[in multi-encoder case,
list of torch.Tensor, [(B, Tmax_1, D), (B, Tmax_2, D), ..., ] ]
:param torch.Tensor hlens: batch of lengths of hidden state sequences (B)
[in multi-encoder case, list of torch.Tensor, [(B), (B), ..., ]
:param torch.Tensor ys_pad: batch of padded character id sequence tensor (B, Lmax)
:param int strm_idx: stream index indicates the index of decoding stream.
:param torch.Tensor lang_ids: batch of target language id tensor (B, 1)
:return: attention loss value
:rtype: torch.Tensor
:return: accuracy
:rtype: float
"""
# to support mutiple encoder asr mode, in single encoder mode, convert torch.Tensor to List of torch.Tensor
if self.num_encs == 1:
hs_pad = [hs_pad]
hlens = [hlens]
# TODO(kan-bayashi): need to make more smart way
ys = [y[y != self.ignore_id] for y in ys_pad] # parse padded ys
# attention index for the attention module
# in SPA (speaker parallel attention), att_idx is used to select attention module. In other cases, it is 0.
att_idx = min(strm_idx, len(self.att) - 1)
# hlens should be list of list of integer
hlens = [list(map(int, hlens[idx])) for idx in range(self.num_encs)]
self.loss = None
# prepare input and output word sequences with sos/eos IDs
eos = ys[0].new([self.eos])
sos = ys[0].new([self.sos])
if self.replace_sos:
ys_in = [torch.cat([idx, y], dim=0) for idx, y in zip(lang_ids, ys)]
else:
ys_in = [torch.cat([sos, y], dim=0) for y in ys]
ys_out = [torch.cat([y, eos], dim=0) for y in ys]
# padding for ys with -1
# pys: utt x olen
ys_in_pad = pad_list(ys_in, self.eos)
ys_out_pad = pad_list(ys_out, self.ignore_id)
# get dim, length info
batch = ys_out_pad.size(0)
olength = ys_out_pad.size(1)
for idx in range(self.num_encs):
logging.info(
self.__class__.__name__ + 'Number of Encoder:{}; enc{}: input lengths: {}.'.format(self.num_encs,
idx + 1, hlens[idx]))
logging.info(self.__class__.__name__ + ' output lengths: ' + str([y.size(0) for y in ys_out]))
# initialization
c_list = [self.zero_state(hs_pad[0])]
z_list = [self.zero_state(hs_pad[0])]
for _ in six.moves.range(1, self.dlayers):
c_list.append(self.zero_state(hs_pad[0]))
z_list.append(self.zero_state(hs_pad[0]))
z_all = []
if self.num_encs == 1:
att_w = None
self.att[att_idx].reset() # reset pre-computation of h
else:
att_w_list = [None] * (self.num_encs + 1) # atts + han
att_c_list = [None] * (self.num_encs) # atts
for idx in range(self.num_encs + 1):
self.att[idx].reset() # reset pre-computation of h in atts and han
# pre-computation of embedding
eys = self.dropout_emb(self.embed(ys_in_pad)) # utt x olen x zdim
# loop for an output sequence
for i in six.moves.range(olength):
if self.num_encs == 1:
att_c, att_w = self.att[att_idx](hs_pad[0], hlens[0], self.dropout_dec[0](z_list[0]), att_w)
else:
for idx in range(self.num_encs):
att_c_list[idx], att_w_list[idx] = self.att[idx](hs_pad[idx], hlens[idx],
self.dropout_dec[0](z_list[0]), att_w_list[idx])
hs_pad_han = torch.stack(att_c_list, dim=1)
hlens_han = [self.num_encs] * len(ys_in)
att_c, att_w_list[self.num_encs] = self.att[self.num_encs](hs_pad_han, hlens_han,
self.dropout_dec[0](z_list[0]),
att_w_list[self.num_encs])
if i > 0 and random.random() < self.sampling_probability:
logging.info(' scheduled sampling ')
z_out = self.output(z_all[-1])
z_out = np.argmax(z_out.detach().cpu(), axis=1)
z_out = self.dropout_emb(self.embed(to_device(self, z_out)))
ey = torch.cat((z_out, att_c), dim=1) # utt x (zdim + hdim)
else:
ey = torch.cat((eys[:, i, :], att_c), dim=1) # utt x (zdim + hdim)
z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list)
if self.context_residual:
z_all.append(torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)) # utt x (zdim + hdim)
else:
z_all.append(self.dropout_dec[-1](z_list[-1])) # utt x (zdim)
z_all = torch.stack(z_all, dim=1).view(batch * olength, -1)
# compute loss
y_all = self.output(z_all)
if LooseVersion(torch.__version__) < LooseVersion('1.0'):
reduction_str = 'elementwise_mean'
else:
reduction_str = 'mean'
self.loss = F.cross_entropy(y_all, ys_out_pad.view(-1),
ignore_index=self.ignore_id,
reduction=reduction_str)
# compute perplexity
ppl = math.exp(self.loss.item())
# -1: eos, which is removed in the loss computation
self.loss *= (np.mean([len(x) for x in ys_in]) - 1)
acc = th_accuracy(y_all, ys_out_pad, ignore_label=self.ignore_id)
logging.info('att loss:' + ''.join(str(self.loss.item()).split('\n')))
# show predicted character sequence for debug
if self.verbose > 0 and self.char_list is not None:
ys_hat = y_all.view(batch, olength, -1)
ys_true = ys_out_pad
for (i, y_hat), y_true in zip(enumerate(ys_hat.detach().cpu().numpy()),
ys_true.detach().cpu().numpy()):
if i == MAX_DECODER_OUTPUT:
break
idx_hat = np.argmax(y_hat[y_true != self.ignore_id], axis=1)
idx_true = y_true[y_true != self.ignore_id]
seq_hat = [self.char_list[int(idx)] for idx in idx_hat]
seq_true = [self.char_list[int(idx)] for idx in idx_true]
seq_hat = "".join(seq_hat)
seq_true = "".join(seq_true)
logging.info("groundtruth[%d]: " % i + seq_true)
logging.info("prediction [%d]: " % i + seq_hat)
if self.labeldist is not None:
if self.vlabeldist is None:
self.vlabeldist = to_device(self, torch.from_numpy(self.labeldist))
loss_reg = - torch.sum((F.log_softmax(y_all, dim=1) * self.vlabeldist).view(-1), dim=0) / len(ys_in)
self.loss = (1. - self.lsm_weight) * self.loss + self.lsm_weight * loss_reg
return self.loss, acc, ppl
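A hedged training-step sketch, continuing the toy decoder and the placeholder tensors from the sketches above.
loss, acc, ppl = dec(hs_pad, hlens, ys_pad)  # attention loss, accuracy, perplexity
loss.backward()                              # as in a normal training step
print(float(loss), acc, ppl)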
init_state(self, x)
¶
Get an initial state for decoding (optional).
Parameters:
Name | Type | Description | Default
---|---|---|---
x | torch.Tensor | The encoded feature tensor | required
Returns: initial state
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/decoders.py
def init_state(self, x):
# to support mutiple encoder asr mode, in single encoder mode, convert torch.Tensor to List of torch.Tensor
if self.num_encs == 1:
x = [x]
c_list = [self.zero_state(x[0].unsqueeze(0))]
z_list = [self.zero_state(x[0].unsqueeze(0))]
for _ in six.moves.range(1, self.dlayers):
c_list.append(self.zero_state(x[0].unsqueeze(0)))
z_list.append(self.zero_state(x[0].unsqueeze(0)))
# TODO(karita): support strm_index for `asr_mix`
strm_index = 0
att_idx = min(strm_index, len(self.att) - 1)
if self.num_encs == 1:
a = None
self.att[att_idx].reset() # reset pre-computation of h
else:
a = [None] * (self.num_encs + 1) # atts + han
for idx in range(self.num_encs + 1):
self.att[idx].reset() # reset pre-computation of h in atts and han
return dict(c_prev=c_list[:], z_prev=z_list[:], a_prev=a, workspace=(att_idx, z_list, c_list))
recognize_beam(self, h, lpz, recog_args, char_list, rnnlm=None, strm_idx=0)
¶
beam search implementation
:param torch.Tensor h: encoder hidden state (T, eprojs) [in multi-encoder case, list of torch.Tensor, [(T1, eprojs), (T2, eprojs), ...] ]
:param torch.Tensor lpz: ctc log softmax output (T, odim) [in multi-encoder case, list of torch.Tensor, [(T1, odim), (T2, odim), ...] ]
:param Namespace recog_args: argument Namespace containing options
:param char_list: list of character strings
:param torch.nn.Module rnnlm: language module
:param int strm_idx: stream index for speaker parallel attention in multi-speaker case
:return: N-best decoding results
:rtype: list of dicts
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/decoders.py
def recognize_beam(self, h, lpz, recog_args, char_list, rnnlm=None, strm_idx=0):
"""beam search implementation
:param torch.Tensor h: encoder hidden state (T, eprojs)
[in multi-encoder case, list of torch.Tensor, [(T1, eprojs), (T2, eprojs), ...] ]
:param torch.Tensor lpz: ctc log softmax output (T, odim)
[in multi-encoder case, list of torch.Tensor, [(T1, odim), (T2, odim), ...] ]
:param Namespace recog_args: argument Namespace containing options
:param char_list: list of character strings
:param torch.nn.Module rnnlm: language module
:param int strm_idx: stream index for speaker parallel attention in multi-speaker case
:return: N-best decoding results
:rtype: list of dicts
"""
# to support mutiple encoder asr mode, in single encoder mode, convert torch.Tensor to List of torch.Tensor
if self.num_encs == 1:
h = [h]
lpz = [lpz]
if self.num_encs > 1 and lpz is None:
lpz = [lpz] * self.num_encs
for idx in range(self.num_encs):
logging.info('Number of Encoder:{}; enc{}: input lengths: {}.'.format(self.num_encs, idx + 1, h[0].size(0)))
att_idx = min(strm_idx, len(self.att) - 1)
# initialization
c_list = [self.zero_state(h[0].unsqueeze(0))]
z_list = [self.zero_state(h[0].unsqueeze(0))]
for _ in six.moves.range(1, self.dlayers):
c_list.append(self.zero_state(h[0].unsqueeze(0)))
z_list.append(self.zero_state(h[0].unsqueeze(0)))
if self.num_encs == 1:
a = None
self.att[att_idx].reset() # reset pre-computation of h
else:
a = [None] * (self.num_encs + 1) # atts + han
att_w_list = [None] * (self.num_encs + 1) # atts + han
att_c_list = [None] * (self.num_encs) # atts
for idx in range(self.num_encs + 1):
self.att[idx].reset() # reset pre-computation of h in atts and han
# search parms
beam = recog_args.beam_size
penalty = recog_args.penalty
ctc_weight = getattr(recog_args, "ctc_weight", False) # for NMT
if lpz[0] is not None and self.num_encs > 1:
# weights-ctc, e.g. ctc_loss = w_1*ctc_1_loss + w_2 * ctc_2_loss + w_N * ctc_N_loss
weights_ctc_dec = recog_args.weights_ctc_dec / np.sum(recog_args.weights_ctc_dec) # normalize
logging.info('ctc weights (decoding): ' + ' '.join([str(x) for x in weights_ctc_dec]))
else:
weights_ctc_dec = [1.0]
# preprate sos
if self.replace_sos and recog_args.tgt_lang:
y = char_list.index(recog_args.tgt_lang)
else:
y = self.sos
logging.info('<sos> index: ' + str(y))
logging.info('<sos> mark: ' + char_list[y])
vy = h[0].new_zeros(1).long()
maxlen = np.amin([h[idx].size(0) for idx in range(self.num_encs)])
if recog_args.maxlenratio != 0:
# maxlen >= 1
maxlen = max(1, int(recog_args.maxlenratio * maxlen))
minlen = int(recog_args.minlenratio * maxlen)
logging.info('max output length: ' + str(maxlen))
logging.info('min output length: ' + str(minlen))
# initialize hypothesis
if rnnlm:
hyp = {'score': 0.0, 'yseq': [y], 'c_prev': c_list,
'z_prev': z_list, 'a_prev': a, 'rnnlm_prev': None}
else:
hyp = {'score': 0.0, 'yseq': [y], 'c_prev': c_list, 'z_prev': z_list, 'a_prev': a}
if lpz[0] is not None:
ctc_prefix_score = [CTCPrefixScore(lpz[idx].detach().numpy(), 0, self.eos, np) for idx in
range(self.num_encs)]
hyp['ctc_state_prev'] = [ctc_prefix_score[idx].initial_state() for idx in range(self.num_encs)]
hyp['ctc_score_prev'] = [0.0] * self.num_encs
if ctc_weight != 1.0:
# pre-pruning based on attention scores
ctc_beam = min(lpz[0].shape[-1], int(beam * CTC_SCORING_RATIO))
else:
ctc_beam = lpz[0].shape[-1]
hyps = [hyp]
ended_hyps = []
for i in six.moves.range(maxlen):
logging.debug('position ' + str(i))
hyps_best_kept = []
for hyp in hyps:
vy.unsqueeze(1)
vy[0] = hyp['yseq'][i]
ey = self.dropout_emb(self.embed(vy)) # utt list (1) x zdim
ey.unsqueeze(0)
if self.num_encs == 1:
att_c, att_w = self.att[att_idx](h[0].unsqueeze(0), [h[0].size(0)],
self.dropout_dec[0](hyp['z_prev'][0]), hyp['a_prev'])
else:
for idx in range(self.num_encs):
att_c_list[idx], att_w_list[idx] = self.att[idx](h[idx].unsqueeze(0), [h[idx].size(0)],
self.dropout_dec[0](hyp['z_prev'][0]),
hyp['a_prev'][idx])
h_han = torch.stack(att_c_list, dim=1)
att_c, att_w_list[self.num_encs] = self.att[self.num_encs](h_han, [self.num_encs],
self.dropout_dec[0](hyp['z_prev'][0]),
hyp['a_prev'][self.num_encs])
ey = torch.cat((ey, att_c), dim=1) # utt(1) x (zdim + hdim)
z_list, c_list = self.rnn_forward(ey, z_list, c_list, hyp['z_prev'], hyp['c_prev'])
# get nbest local scores and their ids
if self.context_residual:
logits = self.output(torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1))
else:
logits = self.output(self.dropout_dec[-1](z_list[-1]))
local_att_scores = F.log_softmax(logits, dim=1)
if rnnlm:
rnnlm_state, local_lm_scores = rnnlm.predict(hyp['rnnlm_prev'], vy)
local_scores = local_att_scores + recog_args.lm_weight * local_lm_scores
else:
local_scores = local_att_scores
if lpz[0] is not None:
local_best_scores, local_best_ids = torch.topk(
local_att_scores, ctc_beam, dim=1)
ctc_scores, ctc_states = [None] * self.num_encs, [None] * self.num_encs
for idx in range(self.num_encs):
ctc_scores[idx], ctc_states[idx] = ctc_prefix_score[idx](
hyp['yseq'], local_best_ids[0], hyp['ctc_state_prev'][idx])
local_scores = \
(1.0 - ctc_weight) * local_att_scores[:, local_best_ids[0]]
if self.num_encs == 1:
local_scores += ctc_weight * torch.from_numpy(ctc_scores[0] - hyp['ctc_score_prev'][0])
else:
for idx in range(self.num_encs):
local_scores += ctc_weight * weights_ctc_dec[idx] * torch.from_numpy(
ctc_scores[idx] - hyp['ctc_score_prev'][idx])
if rnnlm:
local_scores += recog_args.lm_weight * local_lm_scores[:, local_best_ids[0]]
local_best_scores, joint_best_ids = torch.topk(local_scores, beam, dim=1)
local_best_ids = local_best_ids[:, joint_best_ids[0]]
else:
local_best_scores, local_best_ids = torch.topk(local_scores, beam, dim=1)
for j in six.moves.range(beam):
new_hyp = {}
# [:] is needed!
new_hyp['z_prev'] = z_list[:]
new_hyp['c_prev'] = c_list[:]
if self.num_encs == 1:
new_hyp['a_prev'] = att_w[:]
else:
new_hyp['a_prev'] = [att_w_list[idx][:] for idx in range(self.num_encs + 1)]
new_hyp['score'] = hyp['score'] + local_best_scores[0, j]
new_hyp['yseq'] = [0] * (1 + len(hyp['yseq']))
new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq']
new_hyp['yseq'][len(hyp['yseq'])] = int(local_best_ids[0, j])
if rnnlm:
new_hyp['rnnlm_prev'] = rnnlm_state
if lpz[0] is not None:
new_hyp['ctc_state_prev'] = [ctc_states[idx][joint_best_ids[0, j]] for idx in
range(self.num_encs)]
new_hyp['ctc_score_prev'] = [ctc_scores[idx][joint_best_ids[0, j]] for idx in
range(self.num_encs)]
# will be (2 x beam) hyps at most
hyps_best_kept.append(new_hyp)
hyps_best_kept = sorted(
hyps_best_kept, key=lambda x: x['score'], reverse=True)[:beam]
# sort and get nbest
hyps = hyps_best_kept
logging.debug('number of pruned hypotheses: ' + str(len(hyps)))
logging.debug(
'best hypo: ' + ''.join([char_list[int(x)] for x in hyps[0]['yseq'][1:]]))
# add eos in the final loop to avoid that there are no ended hyps
if i == maxlen - 1:
logging.info('adding <eos> in the last position in the loop')
for hyp in hyps:
hyp['yseq'].append(self.eos)
# add ended hypotheses to a final list, and removed them from current hypotheses
# (this will be a problem, number of hyps < beam)
remained_hyps = []
for hyp in hyps:
if hyp['yseq'][-1] == self.eos:
# only store the sequence that has more than minlen outputs
# also add penalty
if len(hyp['yseq']) > minlen:
hyp['score'] += (i + 1) * penalty
if rnnlm: # Word LM needs to add final <eos> score
hyp['score'] += recog_args.lm_weight * rnnlm.final(
hyp['rnnlm_prev'])
ended_hyps.append(hyp)
else:
remained_hyps.append(hyp)
# end detection
if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
logging.info('end detected at %d', i)
break
hyps = remained_hyps
if len(hyps) > 0:
logging.debug('remaining hypotheses: ' + str(len(hyps)))
else:
logging.info('no hypothesis. Finish decoding.')
break
for hyp in hyps:
logging.debug(
'hypo: ' + ''.join([char_list[int(x)] for x in hyp['yseq'][1:]]))
logging.debug('number of ended hypotheses: ' + str(len(ended_hyps)))
nbest_hyps = sorted(
ended_hyps, key=lambda x: x['score'], reverse=True)[:min(len(ended_hyps), recog_args.nbest)]
# check number of hypotheses
if len(nbest_hyps) == 0:
logging.warning('there is no N-best results, perform recognition again with smaller minlenratio.')
# should copy because Namespace will be overwritten globally
recog_args = Namespace(**vars(recog_args))
recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1)
if self.num_encs == 1:
return self.recognize_beam(h[0], lpz[0], recog_args, char_list, rnnlm)
else:
return self.recognize_beam(h, lpz, recog_args, char_list, rnnlm)
logging.info('total log probability: ' + str(nbest_hyps[0]['score']))
logging.info('normalized log probability: ' + str(nbest_hyps[0]['score'] / len(nbest_hyps[0]['yseq'])))
# remove sos
return nbest_hyps
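A hedged sketch of single-utterance beam search with the toy decoder above; the Namespace lists only the options this method reads for single-encoder, attention-only decoding (no CTC, no language model).
from argparse import Namespace

recog_args = Namespace(beam_size=4, penalty=0.0, ctc_weight=0.0,
                       maxlenratio=0.0, minlenratio=0.0, nbest=2)
char_list = [str(i) for i in range(odim)]   # placeholder symbol table
h = torch.randn(50, eprojs)                 # encoder output for one utterance
with torch.no_grad():
    nbest = dec.recognize_beam(h, None, recog_args, char_list)
print(nbest[0]['yseq'], float(nbest[0]['score']))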
recognize_beam_batch(self, h, hlens, lpz, recog_args, char_list, rnnlm=None, normalize_score=True, strm_idx=0, lang_ids=None)
¶
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/decoders.py
def recognize_beam_batch(self, h, hlens, lpz, recog_args, char_list, rnnlm=None,
normalize_score=True, strm_idx=0, lang_ids=None):
# to support mutiple encoder asr mode, in single encoder mode, convert torch.Tensor to List of torch.Tensor
if self.num_encs == 1:
h = [h]
hlens = [hlens]
lpz = [lpz]
if self.num_encs > 1 and lpz is None:
lpz = [lpz] * self.num_encs
att_idx = min(strm_idx, len(self.att) - 1)
for idx in range(self.num_encs):
logging.info(
'Number of Encoder:{}; enc{}: input lengths: {}.'.format(self.num_encs, idx + 1, h[idx].size(1)))
h[idx] = mask_by_length(h[idx], hlens[idx], 0.0)
# search params
batch = len(hlens[0])
beam = recog_args.beam_size
penalty = recog_args.penalty
ctc_weight = getattr(recog_args, "ctc_weight", 0) # for NMT
att_weight = 1.0 - ctc_weight
ctc_margin = getattr(recog_args, "ctc_window_margin", 0) # use getattr to keep compatibility
# weights-ctc, e.g. ctc_loss = w_1*ctc_1_loss + w_2 * ctc_2_loss + w_N * ctc_N_loss
if lpz[0] is not None and self.num_encs > 1:
weights_ctc_dec = recog_args.weights_ctc_dec / np.sum(recog_args.weights_ctc_dec) # normalize
logging.info('ctc weights (decoding): ' + ' '.join([str(x) for x in weights_ctc_dec]))
else:
weights_ctc_dec = [1.0]
n_bb = batch * beam
pad_b = to_device(self, torch.arange(batch) * beam).view(-1, 1)
max_hlen = np.amin([max(hlens[idx]) for idx in range(self.num_encs)])
if recog_args.maxlenratio == 0:
maxlen = max_hlen
else:
maxlen = max(1, int(recog_args.maxlenratio * max_hlen))
minlen = int(recog_args.minlenratio * max_hlen)
logging.info('max output length: ' + str(maxlen))
logging.info('min output length: ' + str(minlen))
# initialization
c_prev = [to_device(self, torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)]
z_prev = [to_device(self, torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)]
c_list = [to_device(self, torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)]
z_list = [to_device(self, torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)]
vscores = to_device(self, torch.zeros(batch, beam))
rnnlm_state = None
if self.num_encs == 1:
a_prev = [None]
att_w_list, ctc_scorer, ctc_state = [None], [None], [None]
self.att[att_idx].reset() # reset pre-computation of h
else:
a_prev = [None] * (self.num_encs + 1) # atts + han
att_w_list = [None] * (self.num_encs + 1) # atts + han
att_c_list = [None] * (self.num_encs) # atts
ctc_scorer, ctc_state = [None] * (self.num_encs), [None] * (self.num_encs)
for idx in range(self.num_encs + 1):
self.att[idx].reset() # reset pre-computation of h in atts and han
if self.replace_sos and recog_args.tgt_lang:
logging.info('<sos> index: ' + str(char_list.index(recog_args.tgt_lang)))
logging.info('<sos> mark: ' + recog_args.tgt_lang)
yseq = [[char_list.index(recog_args.tgt_lang)] for _ in six.moves.range(n_bb)]
elif lang_ids is not None:
# NOTE: used for evaluation during training
yseq = [[lang_ids[b // recog_args.beam_size]] for b in six.moves.range(n_bb)]
else:
logging.info('<sos> index: ' + str(self.sos))
logging.info('<sos> mark: ' + char_list[self.sos])
yseq = [[self.sos] for _ in six.moves.range(n_bb)]
accum_odim_ids = [self.sos for _ in six.moves.range(n_bb)]
stop_search = [False for _ in six.moves.range(batch)]
nbest_hyps = [[] for _ in six.moves.range(batch)]
ended_hyps = [[] for _ in range(batch)]
exp_hlens = [hlens[idx].repeat(beam).view(beam, batch).transpose(0, 1).contiguous() for idx in
range(self.num_encs)]
exp_hlens = [exp_hlens[idx].view(-1).tolist() for idx in range(self.num_encs)]
exp_h = [h[idx].unsqueeze(1).repeat(1, beam, 1, 1).contiguous() for idx in range(self.num_encs)]
exp_h = [exp_h[idx].view(n_bb, h[idx].size()[1], h[idx].size()[2]) for idx in range(self.num_encs)]
if lpz[0] is not None:
scoring_ratio = CTC_SCORING_RATIO if att_weight > 0.0 and not lpz[0].is_cuda else 0
ctc_scorer = [CTCPrefixScoreTH(lpz[idx], hlens[idx], 0, self.eos, beam,
scoring_ratio, margin=ctc_margin) for idx in range(self.num_encs)]
for i in six.moves.range(maxlen):
logging.debug('position ' + str(i))
vy = to_device(self, torch.LongTensor(self._get_last_yseq(yseq)))
ey = self.dropout_emb(self.embed(vy))
if self.num_encs == 1:
att_c, att_w = self.att[att_idx](exp_h[0], exp_hlens[0], self.dropout_dec[0](z_prev[0]), a_prev[0])
att_w_list = [att_w]
else:
for idx in range(self.num_encs):
att_c_list[idx], att_w_list[idx] = self.att[idx](exp_h[idx], exp_hlens[idx],
self.dropout_dec[0](z_prev[0]), a_prev[idx])
exp_h_han = torch.stack(att_c_list, dim=1)
att_c, att_w_list[self.num_encs] = self.att[self.num_encs](exp_h_han, [self.num_encs] * n_bb,
self.dropout_dec[0](z_prev[0]),
a_prev[self.num_encs])
ey = torch.cat((ey, att_c), dim=1)
# attention decoder
z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_prev, c_prev)
if self.context_residual:
logits = self.output(torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1))
else:
logits = self.output(self.dropout_dec[-1](z_list[-1]))
local_scores = att_weight * F.log_softmax(logits, dim=1)
# rnnlm
if rnnlm:
rnnlm_state, local_lm_scores = rnnlm.buff_predict(rnnlm_state, vy, n_bb)
local_scores = local_scores + recog_args.lm_weight * local_lm_scores
# ctc
if ctc_scorer[0]:
for idx in range(self.num_encs):
att_w = att_w_list[idx]
att_w_ = att_w if isinstance(att_w, torch.Tensor) else att_w[0]
ctc_state[idx], local_ctc_scores = ctc_scorer[idx](yseq, ctc_state[idx], local_scores, att_w_)
local_scores = local_scores + ctc_weight * weights_ctc_dec[idx] * local_ctc_scores
local_scores = local_scores.view(batch, beam, self.odim)
if i == 0:
local_scores[:, 1:, :] = self.logzero
# accumulate scores
eos_vscores = local_scores[:, :, self.eos] + vscores
vscores = vscores.view(batch, beam, 1).repeat(1, 1, self.odim)
vscores[:, :, self.eos] = self.logzero
vscores = (vscores + local_scores).view(batch, -1)
# global pruning
accum_best_scores, accum_best_ids = torch.topk(vscores, beam, 1)
accum_odim_ids = torch.fmod(accum_best_ids, self.odim).view(-1).data.cpu().tolist()
accum_padded_beam_ids = (torch.div(accum_best_ids, self.odim) + pad_b).view(-1).data.cpu().tolist()
y_prev = yseq[:][:]
yseq = self._index_select_list(yseq, accum_padded_beam_ids)
yseq = self._append_ids(yseq, accum_odim_ids)
vscores = accum_best_scores
vidx = to_device(self, torch.LongTensor(accum_padded_beam_ids))
a_prev = []
num_atts = self.num_encs if self.num_encs == 1 else self.num_encs + 1
for idx in range(num_atts):
if isinstance(att_w_list[idx], torch.Tensor):
_a_prev = torch.index_select(att_w_list[idx].view(n_bb, *att_w_list[idx].shape[1:]), 0, vidx)
elif isinstance(att_w_list[idx], list):
# handle the case of multi-head attention
_a_prev = [torch.index_select(att_w_one.view(n_bb, -1), 0, vidx) for att_w_one in att_w_list[idx]]
else:
# handle the case of location_recurrent when return is a tuple
_a_prev_ = torch.index_select(att_w_list[idx][0].view(n_bb, -1), 0, vidx)
_h_prev_ = torch.index_select(att_w_list[idx][1][0].view(n_bb, -1), 0, vidx)
_c_prev_ = torch.index_select(att_w_list[idx][1][1].view(n_bb, -1), 0, vidx)
_a_prev = (_a_prev_, (_h_prev_, _c_prev_))
a_prev.append(_a_prev)
z_prev = [torch.index_select(z_list[li].view(n_bb, -1), 0, vidx) for li in range(self.dlayers)]
c_prev = [torch.index_select(c_list[li].view(n_bb, -1), 0, vidx) for li in range(self.dlayers)]
# pick ended hyps
if i > minlen:
k = 0
penalty_i = (i + 1) * penalty
thr = accum_best_scores[:, -1]
for samp_i in six.moves.range(batch):
if stop_search[samp_i]:
k = k + beam
continue
for beam_j in six.moves.range(beam):
if eos_vscores[samp_i, beam_j] > thr[samp_i]:
yk = y_prev[k][:]
yk.append(self.eos)
if len(yk) < min(hlens[idx][samp_i] for idx in range(self.num_encs)):
_vscore = eos_vscores[samp_i][beam_j] + penalty_i
if rnnlm:
_vscore += recog_args.lm_weight * rnnlm.final(rnnlm_state, index=k)
_score = _vscore.data.cpu().numpy()
ended_hyps[samp_i].append({'yseq': yk, 'vscore': _vscore, 'score': _score})
k = k + 1
# end detection
stop_search = [stop_search[samp_i] or end_detect(ended_hyps[samp_i], i)
for samp_i in six.moves.range(batch)]
stop_search_summary = list(set(stop_search))
if len(stop_search_summary) == 1 and stop_search_summary[0]:
break
if rnnlm:
rnnlm_state = self._index_select_lm_state(rnnlm_state, 0, vidx)
if ctc_scorer[0]:
for idx in range(self.num_encs):
ctc_state[idx] = ctc_scorer[idx].index_select_state(ctc_state[idx], accum_best_ids)
torch.cuda.empty_cache()
dummy_hyps = [{'yseq': [self.sos, self.eos], 'score': np.array([-float('inf')])}]
ended_hyps = [ended_hyps[samp_i] if len(ended_hyps[samp_i]) != 0 else dummy_hyps
for samp_i in six.moves.range(batch)]
if normalize_score:
for samp_i in six.moves.range(batch):
for x in ended_hyps[samp_i]:
x['score'] /= len(x['yseq'])
nbest_hyps = [sorted(ended_hyps[samp_i], key=lambda x: x['score'],
reverse=True)[:min(len(ended_hyps[samp_i]), recog_args.nbest)]
for samp_i in six.moves.range(batch)]
return nbest_hyps
rnn_forward(self, ey, z_list, c_list, z_prev, c_prev)
¶
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/decoders.py
def rnn_forward(self, ey, z_list, c_list, z_prev, c_prev):
if self.dtype == "lstm":
z_list[0], c_list[0] = self.decoder[0](ey, (z_prev[0], c_prev[0]))
for l in six.moves.range(1, self.dlayers):
z_list[l], c_list[l] = self.decoder[l](
self.dropout_dec[l - 1](z_list[l - 1]), (z_prev[l], c_prev[l]))
else:
z_list[0] = self.decoder[0](ey, z_prev[0])
for l in six.moves.range(1, self.dlayers):
z_list[l] = self.decoder[l](self.dropout_dec[l - 1](z_list[l - 1]), z_prev[l])
return z_list, c_list
score(self, yseq, state, x)
¶
Score new token (required).
Parameters:
Name | Type | Description | Default
---|---|---|---
y | torch.Tensor | 1D torch.int64 prefix tokens. | required
state | | Scorer state for prefix tokens | required
x | torch.Tensor | The encoder feature that generates ys. | required
Returns:
Type | Description
---|---
tuple[torch.Tensor, Any] | Tuple of scores for next token that has a shape of (n_vocab) and next state for ys
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/decoders.py
def score(self, yseq, state, x):
# to support mutiple encoder asr mode, in single encoder mode, convert torch.Tensor to List of torch.Tensor
if self.num_encs == 1:
x = [x]
att_idx, z_list, c_list = state["workspace"]
vy = yseq[-1].unsqueeze(0)
ey = self.dropout_emb(self.embed(vy)) # utt list (1) x zdim
if self.num_encs == 1:
att_c, att_w = self.att[att_idx](
x[0].unsqueeze(0), [x[0].size(0)],
self.dropout_dec[0](state['z_prev'][0]), state['a_prev'])
else:
att_w = [None] * (self.num_encs + 1) # atts + han
att_c_list = [None] * (self.num_encs) # atts
for idx in range(self.num_encs):
att_c_list[idx], att_w[idx] = self.att[idx](x[idx].unsqueeze(0), [x[idx].size(0)],
self.dropout_dec[0](state['z_prev'][0]),
state['a_prev'][idx])
h_han = torch.stack(att_c_list, dim=1)
att_c, att_w[self.num_encs] = self.att[self.num_encs](h_han, [self.num_encs],
self.dropout_dec[0](state['z_prev'][0]),
state['a_prev'][self.num_encs])
ey = torch.cat((ey, att_c), dim=1) # utt(1) x (zdim + hdim)
z_list, c_list = self.rnn_forward(ey, z_list, c_list, state['z_prev'], state['c_prev'])
if self.context_residual:
logits = self.output(torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1))
else:
logits = self.output(self.dropout_dec[-1](z_list[-1]))
logp = F.log_softmax(logits, dim=1).squeeze(0)
return logp, dict(c_prev=c_list[:], z_prev=z_list[:], a_prev=att_w, workspace=(att_idx, z_list, c_list))
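A hedged sketch of the scorer interface (init_state plus score) used for step-wise greedy decoding with the toy decoder above.
x = torch.randn(50, eprojs)            # (T, eprojs) encoder output, one utterance
state = dec.init_state(x)
yseq = torch.tensor([dec.sos])
with torch.no_grad():
    for _ in range(10):
        logp, state = dec.score(yseq, state, x)   # logp: (odim,) log-probabilities
        next_id = int(logp.argmax())
        yseq = torch.cat([yseq, torch.tensor([next_id])])
        if next_id == dec.eos:
            break
print(yseq.tolist())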
zero_state(self, hs_pad)
¶
decoder_for(args, odim, sos, eos, att, labeldist)
¶
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/decoders.py
def decoder_for(args, odim, sos, eos, att, labeldist):
return Decoder(args.eprojs, odim, args.dtype, args.dlayers, args.dunits, sos, eos, att, args.verbose,
args.char_list, labeldist,
args.lsm_weight, args.sampling_probability, args.dropout_rate_decoder,
getattr(args, "context_residual", False), # use getattr to keep compatibility
getattr(args, "replace_sos", False), # use getattr to keep compatibility
getattr(args, "num_encs", 1)) # use getattr to keep compatibility
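A hedged sketch of the argparse-style fields decoder_for reads (all visible in the call above); the import path is assumed, and the attention list reuses the hypothetical UniformAtt from the earlier sketch.
from argparse import Namespace
# assumed import path, based on the file location shown above
from tools.espnet_minimal.nets.pytorch_backend.rnn.decoders import decoder_for

args = Namespace(eprojs=320, dtype="lstm", dlayers=1, dunits=300, verbose=0,
                 char_list=None, lsm_weight=0.0, sampling_probability=0.0,
                 dropout_rate_decoder=0.0)
dec2 = decoder_for(args, odim=52, sos=51, eos=51,
                   att=torch.nn.ModuleList([UniformAtt()]), labeldist=None)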
encoders
¶
Encoder (Module)
¶
Encoder module
:param str etype: type of encoder network
:param int idim: number of dimensions of encoder network
:param int elayers: number of layers of encoder network
:param int eunits: number of lstm units of encoder network
:param int eprojs: number of projection units of encoder network
:param np.ndarray subsample: list of subsampling numbers
:param float dropout: dropout rate
:param int in_channel: number of input channels
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/encoders.py
class Encoder(torch.nn.Module):
"""Encoder module
:param str etype: type of encoder network
:param int idim: number of dimensions of encoder network
:param int elayers: number of layers of encoder network
:param int eunits: number of lstm units of encoder network
:param int eprojs: number of projection units of encoder network
:param np.ndarray subsample: list of subsampling numbers
:param float dropout: dropout rate
:param int in_channel: number of input channels
"""
def __init__(self, etype, idim, elayers, eunits, eprojs, subsample, dropout, in_channel=1):
super(Encoder, self).__init__()
typ = etype.lstrip("vgg").rstrip("p")
if typ not in ['lstm', 'gru', 'blstm', 'bgru']:
logging.error("Error: need to specify an appropriate encoder architecture")
if etype.startswith("vgg"):
if etype[-1] == "p":
self.enc = torch.nn.ModuleList([VGG2L(in_channel),
RNNP(get_vgg2l_odim(idim, in_channel=in_channel), elayers, eunits,
eprojs,
subsample, dropout, typ=typ)])
logging.info('Use CNN-VGG + ' + typ.upper() + 'P for encoder')
else:
self.enc = torch.nn.ModuleList([VGG2L(in_channel),
RNN(get_vgg2l_odim(idim, in_channel=in_channel), elayers, eunits,
eprojs,
dropout, typ=typ)])
logging.info('Use CNN-VGG + ' + typ.upper() + ' for encoder')
else:
if etype[-1] == "p":
self.enc = torch.nn.ModuleList(
[RNNP(idim, elayers, eunits, eprojs, subsample, dropout, typ=typ)])
logging.info(typ.upper() + ' with every-layer projection for encoder')
else:
self.enc = torch.nn.ModuleList([RNN(idim, elayers, eunits, eprojs, dropout, typ=typ)])
logging.info(typ.upper() + ' without projection for encoder')
def forward(self, xs_pad, ilens, prev_states=None):
"""Encoder forward
:param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
:param torch.Tensor ilens: batch of lengths of input sequences (B)
:param torch.Tensor prev_state: batch of previous encoder hidden states (?, ...)
:return: batch of hidden state sequences (B, Tmax, eprojs)
:rtype: torch.Tensor
"""
if prev_states is None:
prev_states = [None] * len(self.enc)
assert len(prev_states) == len(self.enc)
current_states = []
for module, prev_state in zip(self.enc, prev_states):
xs_pad, ilens, states = module(xs_pad, ilens, prev_state=prev_state)
current_states.append(states)
# make mask to remove bias value in padded part
mask = to_device(self, make_pad_mask(ilens).unsqueeze(-1))
return xs_pad.masked_fill(mask, 0.0), ilens, current_states
__init__(self, etype, idim, elayers, eunits, eprojs, subsample, dropout, in_channel=1)
special
¶
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/encoders.py
def __init__(self, etype, idim, elayers, eunits, eprojs, subsample, dropout, in_channel=1):
super(Encoder, self).__init__()
typ = etype.lstrip("vgg").rstrip("p")
if typ not in ['lstm', 'gru', 'blstm', 'bgru']:
logging.error("Error: need to specify an appropriate encoder architecture")
if etype.startswith("vgg"):
if etype[-1] == "p":
self.enc = torch.nn.ModuleList([VGG2L(in_channel),
RNNP(get_vgg2l_odim(idim, in_channel=in_channel), elayers, eunits,
eprojs,
subsample, dropout, typ=typ)])
logging.info('Use CNN-VGG + ' + typ.upper() + 'P for encoder')
else:
self.enc = torch.nn.ModuleList([VGG2L(in_channel),
RNN(get_vgg2l_odim(idim, in_channel=in_channel), elayers, eunits,
eprojs,
dropout, typ=typ)])
logging.info('Use CNN-VGG + ' + typ.upper() + ' for encoder')
else:
if etype[-1] == "p":
self.enc = torch.nn.ModuleList(
[RNNP(idim, elayers, eunits, eprojs, subsample, dropout, typ=typ)])
logging.info(typ.upper() + ' with every-layer projection for encoder')
else:
self.enc = torch.nn.ModuleList([RNN(idim, elayers, eunits, eprojs, dropout, typ=typ)])
logging.info(typ.upper() + ' without projection for encoder')
forward(self, xs_pad, ilens, prev_states=None)
¶
Encoder forward
:param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
:param torch.Tensor ilens: batch of lengths of input sequences (B)
:param torch.Tensor prev_state: batch of previous encoder hidden states (?, ...)
:return: batch of hidden state sequences (B, Tmax, eprojs)
:rtype: torch.Tensor
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/encoders.py
def forward(self, xs_pad, ilens, prev_states=None):
"""Encoder forward
:param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
:param torch.Tensor ilens: batch of lengths of input sequences (B)
:param torch.Tensor prev_state: batch of previous encoder hidden states (?, ...)
:return: batch of hidden state sequences (B, Tmax, eprojs)
:rtype: torch.Tensor
"""
if prev_states is None:
prev_states = [None] * len(self.enc)
assert len(prev_states) == len(self.enc)
current_states = []
for module, prev_state in zip(self.enc, prev_states):
xs_pad, ilens, states = module(xs_pad, ilens, prev_state=prev_state)
current_states.append(states)
# make mask to remove bias value in padded part
mask = to_device(self, make_pad_mask(ilens).unsqueeze(-1))
return xs_pad.masked_fill(mask, 0.0), ilens, current_states
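A hedged, self-contained sketch of the Encoder on random features (import path assumed from the file shown above); "blstmp" selects BLSTM layers with per-layer projection and uses the subsample list.
import torch
from tools.espnet_minimal.nets.pytorch_backend.rnn.encoders import Encoder

idim = 80
subsample = [1, 2, 2, 1, 1]              # elayers + 1 entries; factor-4 time reduction
enc = Encoder("blstmp", idim, elayers=4, eunits=300, eprojs=320,
              subsample=subsample, dropout=0.0)
xs_pad = torch.randn(3, 200, idim)       # (B, Tmax, D), lengths sorted descending
ilens = torch.tensor([200, 180, 150])
hs_pad, hlens, states = enc(xs_pad, ilens)
print(hs_pad.shape)                      # (3, 50, 320) after 4x subsampling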
RNN (Module)
¶
RNN module
:param int idim: dimension of inputs
:param int elayers: number of encoder layers
:param int cdim: number of rnn units (resulted in cdim * 2 if bidirectional)
:param int hdim: number of final projection units
:param float dropout: dropout rate
:param str typ: The RNN type
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/encoders.py
class RNN(torch.nn.Module):
"""RNN module
:param int idim: dimension of inputs
:param int elayers: number of encoder layers
:param int cdim: number of rnn units (resulted in cdim * 2 if bidirectional)
:param int hdim: number of final projection units
:param float dropout: dropout rate
:param str typ: The RNN type
"""
def __init__(self, idim, elayers, cdim, hdim, dropout, typ="blstm"):
super(RNN, self).__init__()
bidir = typ[0] == "b"
self.nbrnn = torch.nn.LSTM(idim, cdim, elayers, batch_first=True,
dropout=dropout, bidirectional=bidir) if "lstm" in typ \
else torch.nn.GRU(idim, cdim, elayers, batch_first=True, dropout=dropout,
bidirectional=bidir)
if bidir:
self.l_last = torch.nn.Linear(cdim * 2, hdim)
else:
self.l_last = torch.nn.Linear(cdim, hdim)
self.typ = typ
def forward(self, xs_pad, ilens, prev_state=None):
"""RNN forward
:param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
:param torch.Tensor ilens: batch of lengths of input sequences (B)
:param torch.Tensor prev_state: batch of previous RNN states
:return: batch of hidden state sequences (B, Tmax, eprojs)
:rtype: torch.Tensor
"""
logging.debug(self.__class__.__name__ + ' input lengths: ' + str(ilens))
xs_pack = pack_padded_sequence(xs_pad, ilens, batch_first=True)
self.nbrnn.flatten_parameters()
if prev_state is not None and self.nbrnn.bidirectional:
# We assume that when previous state is passed, it means that we're streaming the input
# and therefore cannot propagate backward BRNN state (otherwise it goes in the wrong direction)
prev_state = reset_backward_rnn_state(prev_state)
ys, states = self.nbrnn(xs_pack, hx=prev_state)
# ys: utt list of frame x cdim x 2 (2: means bidirectional)
ys_pad, ilens = pad_packed_sequence(ys, batch_first=True)
# (sum _utt frame_utt) x dim
projected = torch.tanh(self.l_last(
ys_pad.contiguous().view(-1, ys_pad.size(2))))
xs_pad = projected.view(ys_pad.size(0), ys_pad.size(1), -1)
return xs_pad, ilens, states # x: utt list of frame x dim
__init__(self, idim, elayers, cdim, hdim, dropout, typ='blstm')
special
¶
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/encoders.py
def __init__(self, idim, elayers, cdim, hdim, dropout, typ="blstm"):
super(RNN, self).__init__()
bidir = typ[0] == "b"
self.nbrnn = torch.nn.LSTM(idim, cdim, elayers, batch_first=True,
dropout=dropout, bidirectional=bidir) if "lstm" in typ \
else torch.nn.GRU(idim, cdim, elayers, batch_first=True, dropout=dropout,
bidirectional=bidir)
if bidir:
self.l_last = torch.nn.Linear(cdim * 2, hdim)
else:
self.l_last = torch.nn.Linear(cdim, hdim)
self.typ = typ
forward(self, xs_pad, ilens, prev_state=None)
¶
RNN forward
:param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
:param torch.Tensor ilens: batch of lengths of input sequences (B)
:param torch.Tensor prev_state: batch of previous RNN states
:return: batch of hidden state sequences (B, Tmax, eprojs)
:rtype: torch.Tensor
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/encoders.py
def forward(self, xs_pad, ilens, prev_state=None):
"""RNN forward
:param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
:param torch.Tensor ilens: batch of lengths of input sequences (B)
:param torch.Tensor prev_state: batch of previous RNN states
:return: batch of hidden state sequences (B, Tmax, eprojs)
:rtype: torch.Tensor
"""
logging.debug(self.__class__.__name__ + ' input lengths: ' + str(ilens))
xs_pack = pack_padded_sequence(xs_pad, ilens, batch_first=True)
self.nbrnn.flatten_parameters()
if prev_state is not None and self.nbrnn.bidirectional:
# We assume that when previous state is passed, it means that we're streaming the input
# and therefore cannot propagate backward BRNN state (otherwise it goes in the wrong direction)
prev_state = reset_backward_rnn_state(prev_state)
ys, states = self.nbrnn(xs_pack, hx=prev_state)
# ys: utt list of frame x cdim x 2 (2: means bidirectional)
ys_pad, ilens = pad_packed_sequence(ys, batch_first=True)
# (sum _utt frame_utt) x dim
projected = torch.tanh(self.l_last(
ys_pad.contiguous().view(-1, ys_pad.size(2))))
xs_pad = projected.view(ys_pad.size(0), ys_pad.size(1), -1)
return xs_pad, ilens, states # x: utt list of frame x dim
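A hedged, self-contained sketch of the projection-less RNN block (import path assumed); note that pack_padded_sequence expects lengths in descending order here.
import torch
from tools.espnet_minimal.nets.pytorch_backend.rnn.encoders import RNN

rnn = RNN(idim=80, elayers=2, cdim=256, hdim=320, dropout=0.1, typ="blstm")
xs_pad = torch.randn(2, 100, 80)
ilens = torch.tensor([100, 80])          # sorted descending
ys_pad, olens, states = rnn(xs_pad, ilens)
print(ys_pad.shape)                      # (2, 100, 320): tanh projection to hdim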
RNNP (Module)
¶
RNN with projection layer module
:param int idim: dimension of inputs
:param int elayers: number of encoder layers
:param int cdim: number of rnn units (resulted in cdim * 2 if bidirectional)
:param int hdim: number of projection units
:param np.ndarray subsample: list of subsampling numbers
:param float dropout: dropout rate
:param str typ: The RNN type
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/encoders.py
class RNNP(torch.nn.Module):
"""RNN with projection layer module
:param int idim: dimension of inputs
:param int elayers: number of encoder layers
:param int cdim: number of rnn units (resulted in cdim * 2 if bidirectional)
:param int hdim: number of projection units
:param np.ndarray subsample: list of subsampling numbers
:param float dropout: dropout rate
:param str typ: The RNN type
"""
def __init__(self, idim, elayers, cdim, hdim, subsample, dropout, typ="blstm"):
super(RNNP, self).__init__()
bidir = typ[0] == "b"
for i in six.moves.range(elayers):
if i == 0:
inputdim = idim
else:
inputdim = hdim
rnn = torch.nn.LSTM(inputdim, cdim, dropout=dropout, num_layers=1, bidirectional=bidir,
batch_first=True) if "lstm" in typ \
else torch.nn.GRU(inputdim, cdim, dropout=dropout, num_layers=1, bidirectional=bidir, batch_first=True)
setattr(self, "%s%d" % ("birnn" if bidir else "rnn", i), rnn)
# bottleneck layer to merge
if bidir:
setattr(self, "bt%d" % i, torch.nn.Linear(2 * cdim, hdim))
else:
setattr(self, "bt%d" % i, torch.nn.Linear(cdim, hdim))
self.elayers = elayers
self.cdim = cdim
self.subsample = subsample
self.typ = typ
self.bidir = bidir
def forward(self, xs_pad, ilens, prev_state=None):
"""RNNP forward
:param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
:param torch.Tensor ilens: batch of lengths of input sequences (B)
:param torch.Tensor prev_state: batch of previous RNN states
:return: batch of hidden state sequences (B, Tmax, hdim)
:rtype: torch.Tensor
"""
logging.debug(self.__class__.__name__ + ' input lengths: ' + str(ilens))
elayer_states = []
for layer in six.moves.range(self.elayers):
xs_pack = pack_padded_sequence(xs_pad, ilens, batch_first=True)
rnn = getattr(self, ("birnn" if self.bidir else "rnn") + str(layer))
rnn.flatten_parameters()
if prev_state is not None and rnn.bidirectional:
prev_state = reset_backward_rnn_state(prev_state)
ys, states = rnn(xs_pack, hx=None if prev_state is None else prev_state[layer])
elayer_states.append(states)
# ys: utt list of frame x cdim x 2 (2: means bidirectional)
ys_pad, ilens = pad_packed_sequence(ys, batch_first=True)
sub = self.subsample[layer + 1]
if sub > 1:
ys_pad = ys_pad[:, ::sub]
ilens = [int(i + 1) // sub for i in ilens]
# (sum _utt frame_utt) x dim
projected = getattr(self, 'bt' + str(layer)
)(ys_pad.contiguous().view(-1, ys_pad.size(2)))
if layer == self.elayers - 1:
xs_pad = projected.view(ys_pad.size(0), ys_pad.size(1), -1)
else:
xs_pad = torch.tanh(projected.view(ys_pad.size(0), ys_pad.size(1), -1))
return xs_pad, ilens, elayer_states # x: utt list of frame x dim
__init__(self, idim, elayers, cdim, hdim, subsample, dropout, typ='blstm')
special
¶
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/encoders.py
def __init__(self, idim, elayers, cdim, hdim, subsample, dropout, typ="blstm"):
super(RNNP, self).__init__()
bidir = typ[0] == "b"
for i in six.moves.range(elayers):
if i == 0:
inputdim = idim
else:
inputdim = hdim
rnn = torch.nn.LSTM(inputdim, cdim, dropout=dropout, num_layers=1, bidirectional=bidir,
batch_first=True) if "lstm" in typ \
else torch.nn.GRU(inputdim, cdim, dropout=dropout, num_layers=1, bidirectional=bidir, batch_first=True)
setattr(self, "%s%d" % ("birnn" if bidir else "rnn", i), rnn)
# bottleneck layer to merge
if bidir:
setattr(self, "bt%d" % i, torch.nn.Linear(2 * cdim, hdim))
else:
setattr(self, "bt%d" % i, torch.nn.Linear(cdim, hdim))
self.elayers = elayers
self.cdim = cdim
self.subsample = subsample
self.typ = typ
self.bidir = bidir
forward(self, xs_pad, ilens, prev_state=None)
¶
RNNP forward
:param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
:param torch.Tensor ilens: batch of lengths of input sequences (B)
:param torch.Tensor prev_state: batch of previous RNN states
:return: batch of hidden state sequences (B, Tmax, hdim)
:rtype: torch.Tensor
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/encoders.py
def forward(self, xs_pad, ilens, prev_state=None):
"""RNNP forward
:param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
:param torch.Tensor ilens: batch of lengths of input sequences (B)
:param torch.Tensor prev_state: batch of previous RNN states
:return: batch of hidden state sequences (B, Tmax, hdim)
:rtype: torch.Tensor
"""
logging.debug(self.__class__.__name__ + ' input lengths: ' + str(ilens))
elayer_states = []
for layer in six.moves.range(self.elayers):
xs_pack = pack_padded_sequence(xs_pad, ilens, batch_first=True)
rnn = getattr(self, ("birnn" if self.bidir else "rnn") + str(layer))
rnn.flatten_parameters()
if prev_state is not None and rnn.bidirectional:
prev_state = reset_backward_rnn_state(prev_state)
ys, states = rnn(xs_pack, hx=None if prev_state is None else prev_state[layer])
elayer_states.append(states)
# ys: utt list of frame x cdim x 2 (2: means bidirectional)
ys_pad, ilens = pad_packed_sequence(ys, batch_first=True)
sub = self.subsample[layer + 1]
if sub > 1:
ys_pad = ys_pad[:, ::sub]
ilens = [int(i + 1) // sub for i in ilens]
# (sum _utt frame_utt) x dim
projected = getattr(self, 'bt' + str(layer)
)(ys_pad.contiguous().view(-1, ys_pad.size(2)))
if layer == self.elayers - 1:
xs_pad = projected.view(ys_pad.size(0), ys_pad.size(1), -1)
else:
xs_pad = torch.tanh(projected.view(ys_pad.size(0), ys_pad.size(1), -1))
return xs_pad, ilens, elayer_states # x: utt list of frame x dim
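A hedged, self-contained sketch for RNNP alone (import path assumed); subsample needs elayers + 1 entries because subsample[layer + 1] is read for each layer.
import torch
from tools.espnet_minimal.nets.pytorch_backend.rnn.encoders import RNNP

rnnp = RNNP(idim=80, elayers=2, cdim=256, hdim=320,
            subsample=[1, 2, 2], dropout=0.0, typ="blstm")
xs_pad = torch.randn(2, 120, 80)
ilens = torch.tensor([120, 96])          # sorted descending
ys_pad, olens, states = rnnp(xs_pad, ilens)
print(ys_pad.shape, olens)               # (2, 30, 320), lengths reduced by 4x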
VGG2L (Module)
¶
VGG-like module
:param int in_channel: number of input channels
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/encoders.py
class VGG2L(torch.nn.Module):
"""VGG-like module
:param int in_channel: number of input channels
"""
def __init__(self, in_channel=1):
super(VGG2L, self).__init__()
# CNN layer (VGG motivated)
self.conv1_1 = torch.nn.Conv2d(in_channel, 64, 3, stride=1, padding=1)
self.conv1_2 = torch.nn.Conv2d(64, 64, 3, stride=1, padding=1)
self.conv2_1 = torch.nn.Conv2d(64, 128, 3, stride=1, padding=1)
self.conv2_2 = torch.nn.Conv2d(128, 128, 3, stride=1, padding=1)
self.in_channel = in_channel
def forward(self, xs_pad, ilens, **kwargs):
"""VGG2L forward
:param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
:param torch.Tensor ilens: batch of lengths of input sequences (B)
:return: batch of padded hidden state sequences (B, Tmax // 4, 128 * D // 4)
:rtype: torch.Tensor
"""
logging.debug(self.__class__.__name__ + ' input lengths: ' + str(ilens))
# x: utt x frame x dim
# xs_pad = F.pad_sequence(xs_pad)
# x: utt x 1 (input channel num) x frame x dim
xs_pad = xs_pad.view(xs_pad.size(0), xs_pad.size(1), self.in_channel,
xs_pad.size(2) // self.in_channel).transpose(1, 2)
# NOTE: max_pool1d ?
xs_pad = F.relu(self.conv1_1(xs_pad))
xs_pad = F.relu(self.conv1_2(xs_pad))
xs_pad = F.max_pool2d(xs_pad, 2, stride=2, ceil_mode=True)
xs_pad = F.relu(self.conv2_1(xs_pad))
xs_pad = F.relu(self.conv2_2(xs_pad))
xs_pad = F.max_pool2d(xs_pad, 2, stride=2, ceil_mode=True)
if torch.is_tensor(ilens):
ilens = ilens.cpu().numpy()
else:
ilens = np.array(ilens, dtype=np.float32)
ilens = np.array(np.ceil(ilens / 2), dtype=np.int64)
ilens = np.array(
np.ceil(np.array(ilens, dtype=np.float32) / 2), dtype=np.int64).tolist()
# x: utt_list of frame (remove zeropaded frames) x (input channel num x dim)
xs_pad = xs_pad.transpose(1, 2)
xs_pad = xs_pad.contiguous().view(
xs_pad.size(0), xs_pad.size(1), xs_pad.size(2) * xs_pad.size(3))
return xs_pad, ilens, None # no state in this layer
__init__(self, in_channel=1)
special
¶Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/encoders.py
def __init__(self, in_channel=1):
super(VGG2L, self).__init__()
# CNN layer (VGG motivated)
self.conv1_1 = torch.nn.Conv2d(in_channel, 64, 3, stride=1, padding=1)
self.conv1_2 = torch.nn.Conv2d(64, 64, 3, stride=1, padding=1)
self.conv2_1 = torch.nn.Conv2d(64, 128, 3, stride=1, padding=1)
self.conv2_2 = torch.nn.Conv2d(128, 128, 3, stride=1, padding=1)
self.in_channel = in_channel
forward(self, xs_pad, ilens, **kwargs)
¶VGG2L forward
:param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
:param torch.Tensor ilens: batch of lengths of input sequences (B)
:return: batch of padded hidden state sequences (B, Tmax // 4, 128 * D // 4)
:rtype: torch.Tensor
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/encoders.py
def forward(self, xs_pad, ilens, **kwargs):
"""VGG2L forward
:param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
:param torch.Tensor ilens: batch of lengths of input sequences (B)
:return: batch of padded hidden state sequences (B, Tmax // 4, 128 * D // 4)
:rtype: torch.Tensor
"""
logging.debug(self.__class__.__name__ + ' input lengths: ' + str(ilens))
# x: utt x frame x dim
# xs_pad = F.pad_sequence(xs_pad)
# x: utt x 1 (input channel num) x frame x dim
xs_pad = xs_pad.view(xs_pad.size(0), xs_pad.size(1), self.in_channel,
xs_pad.size(2) // self.in_channel).transpose(1, 2)
# NOTE: max_pool1d ?
xs_pad = F.relu(self.conv1_1(xs_pad))
xs_pad = F.relu(self.conv1_2(xs_pad))
xs_pad = F.max_pool2d(xs_pad, 2, stride=2, ceil_mode=True)
xs_pad = F.relu(self.conv2_1(xs_pad))
xs_pad = F.relu(self.conv2_2(xs_pad))
xs_pad = F.max_pool2d(xs_pad, 2, stride=2, ceil_mode=True)
if torch.is_tensor(ilens):
ilens = ilens.cpu().numpy()
else:
ilens = np.array(ilens, dtype=np.float32)
ilens = np.array(np.ceil(ilens / 2), dtype=np.int64)
ilens = np.array(
np.ceil(np.array(ilens, dtype=np.float32) / 2), dtype=np.int64).tolist()
# x: utt_list of frame (remove zeropaded frames) x (input channel num x dim)
xs_pad = xs_pad.transpose(1, 2)
xs_pad = xs_pad.contiguous().view(
xs_pad.size(0), xs_pad.size(1), xs_pad.size(2) * xs_pad.size(3))
return xs_pad, ilens, None # no state in this layer
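A minimal usage sketch of VGG2L (the import path is an assumption based on the file location above); it illustrates how the two pooling stages reduce both the time axis and the feature axis by roughly a factor of four.
import torch
from tools.espnet_minimal.nets.pytorch_backend.rnn.encoders import VGG2L  # assumed import path

B, Tmax, D = 2, 100, 40                   # D must be divisible by in_channel
xs_pad = torch.randn(B, Tmax, D)
ilens = torch.tensor([100, 73])

vgg = VGG2L(in_channel=1)
hs, olens, _ = vgg(xs_pad, ilens)
print(hs.shape)                           # torch.Size([2, 25, 1280]) = (B, ceil(Tmax / 4), 128 * D // 4)
print(olens)                              # [25, 19] after two ceil-divisions of the lengths by 2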
encoder_for(args, idim, subsample)
¶Instantiates an encoder module given the program arguments
:param Namespace args: The arguments
:param int or List of integer idim: dimension of input, e.g. 83, or List of dimensions of inputs, e.g. [83,83]
:param List or List of List subsample: subsample factors, e.g. [1,2,2,1,1], or List of subsample factors of each encoder, e.g. [[1,2,2,1,1], [1,2,2,1,1]]
:rtype: torch.nn.Module
:return: The encoder module
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/encoders.py
def encoder_for(args, idim, subsample):
"""Instantiates an encoder module given the program arguments
:param Namespace args: The arguments
:param int or List of integer idim: dimension of input, e.g. 83, or
List of dimensions of inputs, e.g. [83,83]
:param List or List of List subsample: subsample factors, e.g. [1,2,2,1,1], or
List of subsample factors of each encoder. e.g. [[1,2,2,1,1], [1,2,2,1,1]]
:rtype torch.nn.Module
:return: The encoder module
"""
num_encs = getattr(args, "num_encs", 1) # use getattr to keep compatibility
if num_encs == 1:
# compatible with single encoder asr mode
return Encoder(args.etype, idim, args.elayers, args.eunits, args.eprojs, subsample, args.dropout_rate)
elif num_encs >= 1:
enc_list = torch.nn.ModuleList()
for idx in range(num_encs):
enc = Encoder(args.etype[idx], idim[idx], args.elayers[idx], args.eunits[idx], args.eprojs, subsample[idx],
args.dropout_rate[idx])
enc_list.append(enc)
return enc_list
else:
raise ValueError("Number of encoders needs to be more than one. {}".format(num_encs))
reset_backward_rnn_state(states)
¶Sets backward BRNN states to zeroes - useful in processing of sliding windows over the inputs
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/rnn/encoders.py
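As a rough, hypothetical sketch of the idea (assuming PyTorch's convention that forward- and backward-direction states are interleaved along the first axis of the state tensors), zeroing the backward half could look like the following; this is an illustration, not the library code.
import torch

def reset_backward_rnn_state_sketch(states):
    # Hypothetical illustration: zero every second row along dim 0, which corresponds
    # to the backward-direction states of a bidirectional PyTorch RNN.
    if isinstance(states, (list, tuple)):
        for state in states:
            state[1::2] = 0.0
    else:
        states[1::2] = 0.0
    return states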
streaming
special
¶
segment
¶
SegmentStreamingE2E
¶SegmentStreamingE2E constructor.
:param E2E e2e: E2E ASR object
:param recog_args: arguments for "recognize" method of E2E
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/streaming/segment.py
class SegmentStreamingE2E(object):
"""SegmentStreamingE2E constructor.
:param E2E e2e: E2E ASR object
:param recog_args: arguments for "recognize" method of E2E
"""
def __init__(self, e2e, recog_args, rnnlm=None):
self._e2e = e2e
self._recog_args = recog_args
self._char_list = e2e.char_list
self._rnnlm = rnnlm
self._e2e.eval()
self._blank_idx_in_char_list = -1
for idx in range(len(self._char_list)):
if self._char_list[idx] == self._e2e.blank:
self._blank_idx_in_char_list = idx
break
self._subsampling_factor = np.prod(e2e.subsample)
self._activates = 0
self._blank_dur = 0
self._previous_input = []
self._previous_encoder_recurrent_state = None
self._encoder_states = []
self._ctc_posteriors = []
assert self._recog_args.batchsize <= 1, \
"SegmentStreamingE2E works only with batch size <= 1"
assert "b" not in self._e2e.etype, \
"SegmentStreamingE2E works only with uni-directional encoders"
def accept_input(self, x):
"""Call this method each time a new batch of input is available."""
self._previous_input.extend(x)
h, ilen = self._e2e.subsample_frames(x)
# Run encoder and apply greedy search on CTC softmax output
h, _, self._previous_encoder_recurrent_state = self._e2e.enc(
h.unsqueeze(0),
ilen,
self._previous_encoder_recurrent_state
)
z = self._e2e.ctc.argmax(h).squeeze(0)
if self._activates == 0 and z[0] != self._blank_idx_in_char_list:
self._activates = 1
# Rerun encoder with zero state at onset of detection
tail_len = self._subsampling_factor * (self._recog_args.streaming_onset_margin + 1)
h, ilen = self._e2e.subsample_frames(
np.reshape(self._previous_input[-tail_len:], [-1, len(self._previous_input[0])]))
h, _, self._previous_encoder_recurrent_state = self._e2e.enc(
h.unsqueeze(0), ilen, None)
hyp = None
if self._activates == 1:
self._encoder_states.extend(h.squeeze(0))
self._ctc_posteriors.extend(self._e2e.ctc.log_softmax(h).squeeze(0))
if z[0] == self._blank_idx_in_char_list:
self._blank_dur += 1
else:
self._blank_dur = 0
if self._blank_dur >= self._recog_args.streaming_min_blank_dur:
seg_len = len(self._encoder_states) - self._blank_dur + self._recog_args.streaming_offset_margin
if seg_len > 0:
# Run decoder with a detected segment
h = torch.cat(self._encoder_states[:seg_len], dim=0).view(
-1, self._encoder_states[0].size(0))
if self._recog_args.ctc_weight > 0.0:
lpz = torch.cat(self._ctc_posteriors[:seg_len], dim=0).view(
-1, self._ctc_posteriors[0].size(0))
if self._recog_args.batchsize > 0:
lpz = lpz.unsqueeze(0)
normalize_score = False
else:
lpz = None
normalize_score = True
if self._recog_args.batchsize == 0:
hyp = self._e2e.dec.recognize_beam(
h, lpz, self._recog_args, self._char_list, self._rnnlm)
else:
hlens = torch.tensor([h.shape[0]])
hyp = self._e2e.dec.recognize_beam_batch(
h.unsqueeze(0), hlens, lpz, self._recog_args,
self._char_list, self._rnnlm, normalize_score=normalize_score)[0]
self._activates = 0
self._blank_dur = 0
tail_len = self._subsampling_factor * self._recog_args.streaming_onset_margin
self._previous_input = self._previous_input[-tail_len:]
self._encoder_states = []
self._ctc_posteriors = []
return hyp
__init__(self, e2e, recog_args, rnnlm=None)
special
¶Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/streaming/segment.py
def __init__(self, e2e, recog_args, rnnlm=None):
self._e2e = e2e
self._recog_args = recog_args
self._char_list = e2e.char_list
self._rnnlm = rnnlm
self._e2e.eval()
self._blank_idx_in_char_list = -1
for idx in range(len(self._char_list)):
if self._char_list[idx] == self._e2e.blank:
self._blank_idx_in_char_list = idx
break
self._subsampling_factor = np.prod(e2e.subsample)
self._activates = 0
self._blank_dur = 0
self._previous_input = []
self._previous_encoder_recurrent_state = None
self._encoder_states = []
self._ctc_posteriors = []
assert self._recog_args.batchsize <= 1, \
"SegmentStreamingE2E works only with batch size <= 1"
assert "b" not in self._e2e.etype, \
"SegmentStreamingE2E works only with uni-directional encoders"
accept_input(self, x)
¶Call this method each time a new batch of input is available.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/streaming/segment.py
def accept_input(self, x):
"""Call this method each time a new batch of input is available."""
self._previous_input.extend(x)
h, ilen = self._e2e.subsample_frames(x)
# Run encoder and apply greedy search on CTC softmax output
h, _, self._previous_encoder_recurrent_state = self._e2e.enc(
h.unsqueeze(0),
ilen,
self._previous_encoder_recurrent_state
)
z = self._e2e.ctc.argmax(h).squeeze(0)
if self._activates == 0 and z[0] != self._blank_idx_in_char_list:
self._activates = 1
# Rerun encoder with zero state at onset of detection
tail_len = self._subsampling_factor * (self._recog_args.streaming_onset_margin + 1)
h, ilen = self._e2e.subsample_frames(
np.reshape(self._previous_input[-tail_len:], [-1, len(self._previous_input[0])]))
h, _, self._previous_encoder_recurrent_state = self._e2e.enc(
h.unsqueeze(0), ilen, None)
hyp = None
if self._activates == 1:
self._encoder_states.extend(h.squeeze(0))
self._ctc_posteriors.extend(self._e2e.ctc.log_softmax(h).squeeze(0))
if z[0] == self._blank_idx_in_char_list:
self._blank_dur += 1
else:
self._blank_dur = 0
if self._blank_dur >= self._recog_args.streaming_min_blank_dur:
seg_len = len(self._encoder_states) - self._blank_dur + self._recog_args.streaming_offset_margin
if seg_len > 0:
# Run decoder with a detected segment
h = torch.cat(self._encoder_states[:seg_len], dim=0).view(
-1, self._encoder_states[0].size(0))
if self._recog_args.ctc_weight > 0.0:
lpz = torch.cat(self._ctc_posteriors[:seg_len], dim=0).view(
-1, self._ctc_posteriors[0].size(0))
if self._recog_args.batchsize > 0:
lpz = lpz.unsqueeze(0)
normalize_score = False
else:
lpz = None
normalize_score = True
if self._recog_args.batchsize == 0:
hyp = self._e2e.dec.recognize_beam(
h, lpz, self._recog_args, self._char_list, self._rnnlm)
else:
hlens = torch.tensor([h.shape[0]])
hyp = self._e2e.dec.recognize_beam_batch(
h.unsqueeze(0), hlens, lpz, self._recog_args,
self._char_list, self._rnnlm, normalize_score=normalize_score)[0]
self._activates = 0
self._blank_dur = 0
tail_len = self._subsampling_factor * self._recog_args.streaming_onset_margin
self._previous_input = self._previous_input[-tail_len:]
self._encoder_states = []
self._ctc_posteriors = []
return hyp
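A minimal sketch of driving SegmentStreamingE2E chunk by chunk. Here e2e and recog_args are assumed to be a loaded uni-directional E2E ASR model and its recognition arguments (as the constructor requires), and the import path is an assumption based on the file location above.
from tools.espnet_minimal.nets.pytorch_backend.streaming.segment import SegmentStreamingE2E  # assumed path

def stream_segments(e2e, recog_args, feats, chunk=100, rnnlm=None):
    """Feed a feature matrix (frames x dim) in fixed-size chunks and collect one
    hypothesis list per detected speech segment."""
    streamer = SegmentStreamingE2E(e2e, recog_args, rnnlm=rnnlm)
    hyps = []
    for start in range(0, len(feats), chunk):
        hyp = streamer.accept_input(feats[start:start + chunk])
        if hyp is not None:          # a long enough blank run closed the current segment
            hyps.append(hyp)
    return hyps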
window
¶
WindowStreamingE2E
¶WindowStreamingE2E constructor.
:param E2E e2e: E2E ASR object
:param recog_args: arguments for "recognize" method of E2E
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/streaming/window.py
class WindowStreamingE2E(object):
"""WindowStreamingE2E constructor.
:param E2E e2e: E2E ASR object
:param recog_args: arguments for "recognize" method of E2E
"""
def __init__(self, e2e, recog_args, rnnlm=None):
self._e2e = e2e
self._recog_args = recog_args
self._char_list = e2e.char_list
self._rnnlm = rnnlm
self._e2e.eval()
self._offset = 0
self._previous_encoder_recurrent_state = None
self._encoder_states = []
self._ctc_posteriors = []
self._last_recognition = None
assert self._recog_args.ctc_weight > 0.0, \
"WindowStreamingE2E works only with combined CTC and attention decoders."
def accept_input(self, x):
"""Call this method each time a new batch of input is available."""
h, ilen = self._e2e.subsample_frames(x)
# Streaming encoder
h, _, self._previous_encoder_recurrent_state = self._e2e.enc(
h.unsqueeze(0),
ilen,
self._previous_encoder_recurrent_state
)
self._encoder_states.append(h.squeeze(0))
# CTC posteriors for the incoming audio
self._ctc_posteriors.append(self._e2e.ctc.log_softmax(h).squeeze(0))
def _input_window_for_decoder(self, use_all=False):
if use_all:
return torch.cat(self._encoder_states, dim=0), torch.cat(self._ctc_posteriors, dim=0)
def select_unprocessed_windows(window_tensors):
last_offset = self._offset
offset_traversed = 0
selected_windows = []
for es in window_tensors:
if offset_traversed > last_offset:
selected_windows.append(es)
continue
offset_traversed += es.size(1)
return torch.cat(selected_windows, dim=0)
return (
select_unprocessed_windows(self._encoder_states),
select_unprocessed_windows(self._ctc_posteriors)
)
def decode_with_attention_offline(self):
"""Run the attention decoder offline.
Works even if the previous layers (encoder and CTC decoder) were being run in the online mode.
This method should be run after all the audio has been consumed.
This is used mostly to compare the results between offline and online implementation of the previous layers.
"""
h, lpz = self._input_window_for_decoder(use_all=True)
return self._e2e.dec.recognize_beam(h, lpz, self._recog_args, self._char_list, self._rnnlm)
__init__(self, e2e, recog_args, rnnlm=None)
special
¶Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/streaming/window.py
def __init__(self, e2e, recog_args, rnnlm=None):
self._e2e = e2e
self._recog_args = recog_args
self._char_list = e2e.char_list
self._rnnlm = rnnlm
self._e2e.eval()
self._offset = 0
self._previous_encoder_recurrent_state = None
self._encoder_states = []
self._ctc_posteriors = []
self._last_recognition = None
assert self._recog_args.ctc_weight > 0.0, \
"WindowStreamingE2E works only with combined CTC and attention decoders."
accept_input(self, x)
¶Call this method each time a new batch of input is available.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/streaming/window.py
def accept_input(self, x):
"""Call this method each time a new batch of input is available."""
h, ilen = self._e2e.subsample_frames(x)
# Streaming encoder
h, _, self._previous_encoder_recurrent_state = self._e2e.enc(
h.unsqueeze(0),
ilen,
self._previous_encoder_recurrent_state
)
self._encoder_states.append(h.squeeze(0))
# CTC posteriors for the incoming audio
self._ctc_posteriors.append(self._e2e.ctc.log_softmax(h).squeeze(0))
decode_with_attention_offline(self)
¶Run the attention decoder offline.
Works even if the previous layers (encoder and CTC decoder) were being run in the online mode. This method should be run after all the audio has been consumed. This is used mostly to compare the results between offline and online implementation of the previous layers.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/streaming/window.py
def decode_with_attention_offline(self):
"""Run the attention decoder offline.
Works even if the previous layers (encoder and CTC decoder) were being run in the online mode.
This method should be run after all the audio has been consumed.
This is used mostly to compare the results between offline and online implementation of the previous layers.
"""
h, lpz = self._input_window_for_decoder(use_all=True)
return self._e2e.dec.recognize_beam(h, lpz, self._recog_args, self._char_list, self._rnnlm)
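A minimal sketch of the window-based variant: run the encoder and CTC online over fixed windows, then decode with attention once all audio has been consumed. e2e and recog_args are assumed to come from a loaded E2E model with ctc_weight > 0, and the import path is an assumption.
from tools.espnet_minimal.nets.pytorch_backend.streaming.window import WindowStreamingE2E  # assumed path

def recognize_windowed(e2e, recog_args, feats, window=400, rnnlm=None):
    """Consume a feature matrix (frames x dim) window by window, then decode offline."""
    streamer = WindowStreamingE2E(e2e, recog_args, rnnlm=rnnlm)
    for start in range(0, len(feats), window):
        streamer.accept_input(feats[start:start + window])
    return streamer.decode_with_attention_offline()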
tacotron2
special
¶
cbhg
¶
CBHG related modules.
CBHG (Module)
¶CBHG module to convert log Mel-filterbanks to linear spectrogram.
This is a module of CBHG introduced in Tacotron: Towards End-to-End Speech Synthesis (https://arxiv.org/abs/1703.10135). The CBHG converts the sequence of log Mel-filterbanks into a linear spectrogram.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/tacotron2/cbhg.py
class CBHG(torch.nn.Module):
"""CBHG module to convert log Mel-filterbanks to linear spectrogram.
This is a module of CBHG introduced in `Tacotron: Towards End-to-End Speech Synthesis`_.
The CBHG converts the sequence of log Mel-filterbanks into linear spectrogram.
.. _`Tacotron: Towards End-to-End Speech Synthesis`: https://arxiv.org/abs/1703.10135
"""
def __init__(self,
idim,
odim,
conv_bank_layers=8,
conv_bank_chans=128,
conv_proj_filts=3,
conv_proj_chans=256,
highway_layers=4,
highway_units=128,
gru_units=256):
"""Initialize CBHG module.
Args:
idim (int): Dimension of the inputs.
odim (int): Dimension of the outputs.
conv_bank_layers (int, optional): The number of convolution bank layers.
conv_bank_chans (int, optional): The number of channels in convolution bank.
conv_proj_filts (int, optional): Kernel size of convolutional projection layer.
conv_proj_chans (int, optional): The number of channels in convolutional projection layer.
highway_layers (int, optional): The number of highway network layers.
highway_units (int, optional): The number of highway network units.
gru_units (int, optional): The number of GRU units (for both directions).
"""
super(CBHG, self).__init__()
self.idim = idim
self.odim = odim
self.conv_bank_layers = conv_bank_layers
self.conv_bank_chans = conv_bank_chans
self.conv_proj_filts = conv_proj_filts
self.conv_proj_chans = conv_proj_chans
self.highway_layers = highway_layers
self.highway_units = highway_units
self.gru_units = gru_units
# define 1d convolution bank
self.conv_bank = torch.nn.ModuleList()
for k in range(1, self.conv_bank_layers + 1):
if k % 2 != 0:
padding = (k - 1) // 2
else:
padding = ((k - 1) // 2, (k - 1) // 2 + 1)
self.conv_bank += [torch.nn.Sequential(
torch.nn.ConstantPad1d(padding, 0.0),
torch.nn.Conv1d(idim, self.conv_bank_chans, k, stride=1,
padding=0, bias=True),
torch.nn.BatchNorm1d(self.conv_bank_chans),
torch.nn.ReLU())]
# define max pooling (need padding for one-side to keep same length)
self.max_pool = torch.nn.Sequential(
torch.nn.ConstantPad1d((0, 1), 0.0),
torch.nn.MaxPool1d(2, stride=1))
# define 1d convolution projection
self.projections = torch.nn.Sequential(
torch.nn.Conv1d(self.conv_bank_chans * self.conv_bank_layers, self.conv_proj_chans,
self.conv_proj_filts, stride=1,
padding=(self.conv_proj_filts - 1) // 2, bias=True),
torch.nn.BatchNorm1d(self.conv_proj_chans),
torch.nn.ReLU(),
torch.nn.Conv1d(self.conv_proj_chans, self.idim,
self.conv_proj_filts, stride=1,
padding=(self.conv_proj_filts - 1) // 2, bias=True),
torch.nn.BatchNorm1d(self.idim),
)
# define highway network
self.highways = torch.nn.ModuleList()
self.highways += [torch.nn.Linear(idim, self.highway_units)]
for _ in range(self.highway_layers):
self.highways += [HighwayNet(self.highway_units)]
# define bidirectional GRU
self.gru = torch.nn.GRU(self.highway_units, gru_units // 2, num_layers=1,
batch_first=True, bidirectional=True)
# define final projection
self.output = torch.nn.Linear(gru_units, odim, bias=True)
def forward(self, xs, ilens):
"""Calculate forward propagation.
Args:
xs (Tensor): Batch of the padded sequences of inputs (B, Tmax, idim).
ilens (LongTensor): Batch of lengths of each input sequence (B,).
Return:
Tensor: Batch of the padded sequence of outputs (B, Tmax, odim).
LongTensor: Batch of lengths of each output sequence (B,).
"""
xs = xs.transpose(1, 2) # (B, idim, Tmax)
convs = []
for k in range(self.conv_bank_layers):
convs += [self.conv_bank[k](xs)]
convs = torch.cat(convs, dim=1) # (B, #CH * #BANK, Tmax)
convs = self.max_pool(convs)
convs = self.projections(convs).transpose(1, 2) # (B, Tmax, idim)
xs = xs.transpose(1, 2) + convs
# + 1 for dimension adjustment layer
for l in range(self.highway_layers + 1):
xs = self.highways[l](xs)
# sort by length
xs, ilens, sort_idx = self._sort_by_length(xs, ilens)
# total_length needs for DataParallel
# (see https://github.com/pytorch/pytorch/pull/6327)
total_length = xs.size(1)
xs = pack_padded_sequence(xs, ilens, batch_first=True)
self.gru.flatten_parameters()
xs, _ = self.gru(xs)
xs, ilens = pad_packed_sequence(xs, batch_first=True, total_length=total_length)
# revert sorting by length
xs, ilens = self._revert_sort_by_length(xs, ilens, sort_idx)
xs = self.output(xs) # (B, Tmax, odim)
return xs, ilens
def inference(self, x):
"""Inference.
Args:
x (Tensor): The sequences of inputs (T, idim).
Return:
Tensor: The sequence of outputs (T, odim).
"""
assert len(x.size()) == 2
xs = x.unsqueeze(0)
ilens = x.new([x.size(0)]).long()
return self.forward(xs, ilens)[0][0]
def _sort_by_length(self, xs, ilens):
sort_ilens, sort_idx = ilens.sort(0, descending=True)
return xs[sort_idx], ilens[sort_idx], sort_idx
def _revert_sort_by_length(self, xs, ilens, sort_idx):
_, revert_idx = sort_idx.sort(0)
return xs[revert_idx], ilens[revert_idx]
__init__(self, idim, odim, conv_bank_layers=8, conv_bank_chans=128, conv_proj_filts=3, conv_proj_chans=256, highway_layers=4, highway_units=128, gru_units=256)
special
¶Initialize CBHG module.
Parameters:

Name | Type | Description | Default
---|---|---|---
idim | int | Dimension of the inputs. | required
odim | int | Dimension of the outputs. | required
conv_bank_layers | int | The number of convolution bank layers. | 8
conv_bank_chans | int | The number of channels in convolution bank. | 128
conv_proj_filts | int | Kernel size of convolutional projection layer. | 3
conv_proj_chans | int | The number of channels in convolutional projection layer. | 256
highway_layers | int | The number of highway network layers. | 4
highway_units | int | The number of highway network units. | 128
gru_units | int | The number of GRU units (for both directions). | 256
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/tacotron2/cbhg.py
def __init__(self,
idim,
odim,
conv_bank_layers=8,
conv_bank_chans=128,
conv_proj_filts=3,
conv_proj_chans=256,
highway_layers=4,
highway_units=128,
gru_units=256):
"""Initialize CBHG module.
Args:
idim (int): Dimension of the inputs.
odim (int): Dimension of the outputs.
conv_bank_layers (int, optional): The number of convolution bank layers.
conv_bank_chans (int, optional): The number of channels in convolution bank.
conv_proj_filts (int, optional): Kernel size of convolutional projection layer.
conv_proj_chans (int, optional): The number of channels in convolutional projection layer.
highway_layers (int, optional): The number of highway network layers.
highway_units (int, optional): The number of highway network units.
gru_units (int, optional): The number of GRU units (for both directions).
"""
super(CBHG, self).__init__()
self.idim = idim
self.odim = odim
self.conv_bank_layers = conv_bank_layers
self.conv_bank_chans = conv_bank_chans
self.conv_proj_filts = conv_proj_filts
self.conv_proj_chans = conv_proj_chans
self.highway_layers = highway_layers
self.highway_units = highway_units
self.gru_units = gru_units
# define 1d convolution bank
self.conv_bank = torch.nn.ModuleList()
for k in range(1, self.conv_bank_layers + 1):
if k % 2 != 0:
padding = (k - 1) // 2
else:
padding = ((k - 1) // 2, (k - 1) // 2 + 1)
self.conv_bank += [torch.nn.Sequential(
torch.nn.ConstantPad1d(padding, 0.0),
torch.nn.Conv1d(idim, self.conv_bank_chans, k, stride=1,
padding=0, bias=True),
torch.nn.BatchNorm1d(self.conv_bank_chans),
torch.nn.ReLU())]
# define max pooling (need padding for one-side to keep same length)
self.max_pool = torch.nn.Sequential(
torch.nn.ConstantPad1d((0, 1), 0.0),
torch.nn.MaxPool1d(2, stride=1))
# define 1d convolution projection
self.projections = torch.nn.Sequential(
torch.nn.Conv1d(self.conv_bank_chans * self.conv_bank_layers, self.conv_proj_chans,
self.conv_proj_filts, stride=1,
padding=(self.conv_proj_filts - 1) // 2, bias=True),
torch.nn.BatchNorm1d(self.conv_proj_chans),
torch.nn.ReLU(),
torch.nn.Conv1d(self.conv_proj_chans, self.idim,
self.conv_proj_filts, stride=1,
padding=(self.conv_proj_filts - 1) // 2, bias=True),
torch.nn.BatchNorm1d(self.idim),
)
# define highway network
self.highways = torch.nn.ModuleList()
self.highways += [torch.nn.Linear(idim, self.highway_units)]
for _ in range(self.highway_layers):
self.highways += [HighwayNet(self.highway_units)]
# define bidirectional GRU
self.gru = torch.nn.GRU(self.highway_units, gru_units // 2, num_layers=1,
batch_first=True, bidirectional=True)
# define final projection
self.output = torch.nn.Linear(gru_units, odim, bias=True)
forward(self, xs, ilens)
¶Calculate forward propagation.
Parameters:

Name | Type | Description | Default
---|---|---|---
xs | Tensor | Batch of the padded sequences of inputs (B, Tmax, idim). | required
ilens | LongTensor | Batch of lengths of each input sequence (B,). | required

Returns:

Type | Description
---|---
Tensor | Batch of the padded sequence of outputs (B, Tmax, odim).
LongTensor | Batch of lengths of each output sequence (B,).
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/tacotron2/cbhg.py
def forward(self, xs, ilens):
"""Calculate forward propagation.
Args:
xs (Tensor): Batch of the padded sequences of inputs (B, Tmax, idim).
ilens (LongTensor): Batch of lengths of each input sequence (B,).
Return:
Tensor: Batch of the padded sequence of outputs (B, Tmax, odim).
LongTensor: Batch of lengths of each output sequence (B,).
"""
xs = xs.transpose(1, 2) # (B, idim, Tmax)
convs = []
for k in range(self.conv_bank_layers):
convs += [self.conv_bank[k](xs)]
convs = torch.cat(convs, dim=1) # (B, #CH * #BANK, Tmax)
convs = self.max_pool(convs)
convs = self.projections(convs).transpose(1, 2) # (B, Tmax, idim)
xs = xs.transpose(1, 2) + convs
# + 1 for dimension adjustment layer
for l in range(self.highway_layers + 1):
xs = self.highways[l](xs)
# sort by length
xs, ilens, sort_idx = self._sort_by_length(xs, ilens)
# total_length needs for DataParallel
# (see https://github.com/pytorch/pytorch/pull/6327)
total_length = xs.size(1)
xs = pack_padded_sequence(xs, ilens, batch_first=True)
self.gru.flatten_parameters()
xs, _ = self.gru(xs)
xs, ilens = pad_packed_sequence(xs, batch_first=True, total_length=total_length)
# revert sorting by length
xs, ilens = self._revert_sort_by_length(xs, ilens, sort_idx)
xs = self.output(xs) # (B, Tmax, odim)
return xs, ilens
inference(self, x)
¶Inference.
Parameters:

Name | Type | Description | Default
---|---|---|---
x | Tensor | The sequences of inputs (T, idim). | required

Returns:

Type | Description
---|---
Tensor | The sequence of outputs (T, odim).
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/tacotron2/cbhg.py
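A minimal usage sketch of the CBHG module on random tensors; the import path and the 80-to-513 bin configuration are assumptions for illustration, not taken from the source above.
import torch
from tools.espnet_minimal.nets.pytorch_backend.tacotron2.cbhg import CBHG  # assumed import path

cbhg = CBHG(idim=80, odim=513)            # e.g. 80 log-Mel bins -> 513 linear-spectrogram bins

xs = torch.randn(2, 120, 80)              # padded log-Mel sequences (B, Tmax, idim)
ilens = torch.tensor([120, 95])           # lengths; forward sorts and restores the order internally
ys, olens = cbhg(xs, ilens)
print(ys.shape)                           # torch.Size([2, 120, 513])

y = cbhg.inference(torch.randn(150, 80))  # single utterance (T, idim) -> (T, odim)
print(y.shape)                            # torch.Size([150, 513])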
CBHGLoss (Module)
¶Loss function module for CBHG.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/tacotron2/cbhg.py
class CBHGLoss(torch.nn.Module):
"""Loss function module for CBHG."""
def __init__(self, use_masking=True):
"""Initialize CBHG loss module.
Args:
use_masking (bool): Whether to mask padded part in loss calculation.
"""
super(CBHGLoss, self).__init__()
self.use_masking = use_masking
def forward(self, cbhg_outs, spcs, olens):
"""Calculate forward propagation.
Args:
cbhg_outs (Tensor): Batch of CBHG outputs (B, Lmax, spc_dim).
spcs (Tensor): Batch of groundtruth of spectrogram (B, Lmax, spc_dim).
olens (LongTensor): Batch of the lengths of each sequence (B,).
Returns:
Tensor: L1 loss value
Tensor: Mean square error loss value.
"""
# perform masking for padded values
if self.use_masking:
mask = make_non_pad_mask(olens).unsqueeze(-1).to(spcs.device)
spcs = spcs.masked_select(mask)
cbhg_outs = cbhg_outs.masked_select(mask)
# calculate loss
cbhg_l1_loss = F.l1_loss(cbhg_outs, spcs)
cbhg_mse_loss = F.mse_loss(cbhg_outs, spcs)
return cbhg_l1_loss, cbhg_mse_loss
__init__(self, use_masking=True)
special
¶Initialize CBHG loss module.
Parameters:

Name | Type | Description | Default
---|---|---|---
use_masking | bool | Whether to mask padded part in loss calculation. | True
forward(self, cbhg_outs, spcs, olens)
¶Calculate forward propagation.
Parameters:

Name | Type | Description | Default
---|---|---|---
cbhg_outs | Tensor | Batch of CBHG outputs (B, Lmax, spc_dim). | required
spcs | Tensor | Batch of ground-truth spectrograms (B, Lmax, spc_dim). | required
olens | LongTensor | Batch of the lengths of each sequence (B,). | required

Returns:

Type | Description
---|---
Tensor | L1 loss value.
Tensor | Mean square error loss value.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/tacotron2/cbhg.py
def forward(self, cbhg_outs, spcs, olens):
"""Calculate forward propagation.
Args:
cbhg_outs (Tensor): Batch of CBHG outputs (B, Lmax, spc_dim).
spcs (Tensor): Batch of groundtruth of spectrogram (B, Lmax, spc_dim).
olens (LongTensor): Batch of the lengths of each sequence (B,).
Returns:
Tensor: L1 loss value
Tensor: Mean square error loss value.
"""
# perform masking for padded values
if self.use_masking:
mask = make_non_pad_mask(olens).unsqueeze(-1).to(spcs.device)
spcs = spcs.masked_select(mask)
cbhg_outs = cbhg_outs.masked_select(mask)
# calculate loss
cbhg_l1_loss = F.l1_loss(cbhg_outs, spcs)
cbhg_mse_loss = F.mse_loss(cbhg_outs, spcs)
return cbhg_l1_loss, cbhg_mse_loss
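A minimal sketch of combining CBHG with CBHGLoss during training; the shapes and the import path are assumptions for illustration.
import torch
from tools.espnet_minimal.nets.pytorch_backend.tacotron2.cbhg import CBHG, CBHGLoss  # assumed path

cbhg = CBHG(idim=80, odim=513)
criterion = CBHGLoss(use_masking=True)

xs = torch.randn(2, 120, 80)              # (predicted) log-Mel inputs to CBHG
spcs = torch.randn(2, 120, 513)           # ground-truth linear spectrograms
olens = torch.tensor([120, 95])           # valid lengths used for masking

cbhg_outs, _ = cbhg(xs, olens)
l1_loss, mse_loss = criterion(cbhg_outs, spcs, olens)
loss = l1_loss + mse_loss                 # typical combination before calling backward()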
HighwayNet (Module)
¶Highway Network module.
This is a module of Highway Network introduced in Highway Networks (https://arxiv.org/abs/1505.00387).
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/tacotron2/cbhg.py
class HighwayNet(torch.nn.Module):
"""Highway Network module.
This is a module of Highway Network introduced in `Highway Networks`_.
.. _`Highway Networks`: https://arxiv.org/abs/1505.00387
"""
def __init__(self, idim):
"""Initialize Highway Network module.
Args:
idim (int): Dimension of the inputs.
"""
super(HighwayNet, self).__init__()
self.idim = idim
self.projection = torch.nn.Sequential(
torch.nn.Linear(idim, idim),
torch.nn.ReLU())
self.gate = torch.nn.Sequential(
torch.nn.Linear(idim, idim),
torch.nn.Sigmoid())
def forward(self, x):
"""Calculate forward propagation.
Args:
x (Tensor): Batch of inputs (B, ..., idim).
Returns:
Tensor: Batch of outputs, which are the same shape as inputs (B, ..., idim).
"""
proj = self.projection(x)
gate = self.gate(x)
return proj * gate + x * (1.0 - gate)
__init__(self, idim)
special
¶Initialize Highway Network module.
Parameters:

Name | Type | Description | Default
---|---|---|---
idim | int | Dimension of the inputs. | required
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/tacotron2/cbhg.py
def __init__(self, idim):
"""Initialize Highway Network module.
Args:
idim (int): Dimension of the inputs.
"""
super(HighwayNet, self).__init__()
self.idim = idim
self.projection = torch.nn.Sequential(
torch.nn.Linear(idim, idim),
torch.nn.ReLU())
self.gate = torch.nn.Sequential(
torch.nn.Linear(idim, idim),
torch.nn.Sigmoid())
forward(self, x)
¶Calculate forward propagation.
Parameters:

Name | Type | Description | Default
---|---|---|---
x | Tensor | Batch of inputs (B, ..., idim). | required

Returns:

Type | Description
---|---
Tensor | Batch of outputs, which have the same shape as the inputs (B, ..., idim).
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/tacotron2/cbhg.py
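A minimal sketch showing that a highway layer preserves the input dimension, gating between the transformed path and the identity path; the import path is an assumption.
import torch
from tools.espnet_minimal.nets.pytorch_backend.tacotron2.cbhg import HighwayNet  # assumed path

hw = HighwayNet(idim=128)
x = torch.randn(4, 50, 128)               # (B, T, idim)
y = hw(x)                                 # proj * gate + x * (1 - gate)
print(y.shape)                            # torch.Size([4, 50, 128])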
decoder
¶
Tacotron2 decoder related modules.
Decoder (Module)
¶Decoder module of Spectrogram prediction network.
This is a module of the decoder of the Spectrogram prediction network in Tacotron2, which is described in Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions (https://arxiv.org/abs/1712.05884). The decoder generates the sequence of features from the sequence of hidden states.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/tacotron2/decoder.py
class Decoder(torch.nn.Module):
"""Decoder module of Spectrogram prediction network.
This is a module of decoder of Spectrogram prediction network in Tacotron2, which described in `Natural TTS
Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`_. The decoder generates the sequence of
features from the sequence of the hidden states.
.. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
https://arxiv.org/abs/1712.05884
"""
def __init__(self, idim, odim, att,
dlayers=2,
dunits=1024,
prenet_layers=2,
prenet_units=256,
postnet_layers=5,
postnet_chans=512,
postnet_filts=5,
output_activation_fn=None,
cumulate_att_w=True,
use_batch_norm=True,
use_concate=True,
dropout_rate=0.5,
zoneout_rate=0.1,
reduction_factor=1):
"""Initialize Tacotron2 decoder module.
Args:
idim (int): Dimension of the inputs.
odim (int): Dimension of the outputs.
att (torch.nn.Module): Instance of attention class.
dlayers (int, optional): The number of decoder lstm layers.
dunits (int, optional): The number of decoder lstm units.
prenet_layers (int, optional): The number of prenet layers.
prenet_units (int, optional): The number of prenet units.
postnet_layers (int, optional): The number of postnet layers.
postnet_filts (int, optional): The number of postnet filter size.
postnet_chans (int, optional): The number of postnet filter channels.
output_activation_fn (torch.nn.Module, optional): Activation function for outputs.
cumulate_att_w (bool, optional): Whether to cumulate previous attention weight.
use_batch_norm (bool, optional): Whether to use batch normalization.
use_concate (bool, optional): Whether to concatenate encoder embedding with decoder lstm outputs.
dropout_rate (float, optional): Dropout rate.
zoneout_rate (float, optional): Zoneout rate.
reduction_factor (int, optional): Reduction factor.
"""
super(Decoder, self).__init__()
# store the hyperparameters
self.idim = idim
self.odim = odim
self.att = att
self.output_activation_fn = output_activation_fn
self.cumulate_att_w = cumulate_att_w
self.use_concate = use_concate
self.reduction_factor = reduction_factor
# check attention type
if isinstance(self.att, AttForwardTA):
self.use_att_extra_inputs = True
else:
self.use_att_extra_inputs = False
# define lstm network
prenet_units = prenet_units if prenet_layers != 0 else odim
self.lstm = torch.nn.ModuleList()
for layer in six.moves.range(dlayers):
iunits = idim + prenet_units if layer == 0 else dunits
lstm = torch.nn.LSTMCell(iunits, dunits)
if zoneout_rate > 0.0:
lstm = ZoneOutCell(lstm, zoneout_rate)
self.lstm += [lstm]
# define prenet
if prenet_layers > 0:
self.prenet = Prenet(
idim=odim,
n_layers=prenet_layers,
n_units=prenet_units,
dropout_rate=dropout_rate)
else:
self.prenet = None
# define postnet
if postnet_layers > 0:
self.postnet = Postnet(
idim=idim,
odim=odim,
n_layers=postnet_layers,
n_chans=postnet_chans,
n_filts=postnet_filts,
use_batch_norm=use_batch_norm,
dropout_rate=dropout_rate)
else:
self.postnet = None
# define projection layers
iunits = idim + dunits if use_concate else dunits
self.feat_out = torch.nn.Linear(iunits, odim * reduction_factor, bias=False)
self.prob_out = torch.nn.Linear(iunits, reduction_factor)
# initialize
self.apply(decoder_init)
def _zero_state(self, hs):
init_hs = hs.new_zeros(hs.size(0), self.lstm[0].hidden_size)
return init_hs
def forward(self, hs, hlens, ys):
"""Calculate forward propagation.
Args:
hs (Tensor): Batch of the sequences of padded hidden states (B, Tmax, idim).
hlens (LongTensor): Batch of lengths of each input batch (B,).
ys (Tensor): Batch of the sequences of padded target features (B, Lmax, odim).
Returns:
Tensor: Batch of output tensors after postnet (B, Lmax, odim).
Tensor: Batch of output tensors before postnet (B, Lmax, odim).
Tensor: Batch of logits of stop prediction (B, Lmax).
Tensor: Batch of attention weights (B, Lmax, Tmax).
Note:
This computation is performed in teacher-forcing manner.
"""
# thin out frames (B, Lmax, odim) -> (B, Lmax/r, odim)
if self.reduction_factor > 1:
ys = ys[:, self.reduction_factor - 1::self.reduction_factor]
# length list should be list of int
hlens = list(map(int, hlens))
# initialize hidden states of decoder
c_list = [self._zero_state(hs)]
z_list = [self._zero_state(hs)]
for _ in six.moves.range(1, len(self.lstm)):
c_list += [self._zero_state(hs)]
z_list += [self._zero_state(hs)]
prev_out = hs.new_zeros(hs.size(0), self.odim)
# initialize attention
prev_att_w = None
self.att.reset()
# loop for an output sequence
outs, logits, att_ws = [], [], []
for y in ys.transpose(0, 1):
if self.use_att_extra_inputs:
att_c, att_w = self.att(hs, hlens, z_list[0], prev_att_w, prev_out)
else:
att_c, att_w = self.att(hs, hlens, z_list[0], prev_att_w)
prenet_out = self.prenet(prev_out) if self.prenet is not None else prev_out
xs = torch.cat([att_c, prenet_out], dim=1)
z_list[0], c_list[0] = self.lstm[0](xs, (z_list[0], c_list[0]))
for l in six.moves.range(1, len(self.lstm)):
z_list[l], c_list[l] = self.lstm[l](
z_list[l - 1], (z_list[l], c_list[l]))
zcs = torch.cat([z_list[-1], att_c], dim=1) if self.use_concate else z_list[-1]
outs += [self.feat_out(zcs).view(hs.size(0), self.odim, -1)]
logits += [self.prob_out(zcs)]
att_ws += [att_w]
prev_out = y # teacher forcing
if self.cumulate_att_w and prev_att_w is not None:
prev_att_w = prev_att_w + att_w # Note: error when use +=
else:
prev_att_w = att_w
logits = torch.cat(logits, dim=1) # (B, Lmax)
before_outs = torch.cat(outs, dim=2) # (B, odim, Lmax)
att_ws = torch.stack(att_ws, dim=1) # (B, Lmax, Tmax)
if self.reduction_factor > 1:
before_outs = before_outs.view(before_outs.size(0), self.odim, -1) # (B, odim, Lmax)
if self.postnet is not None:
after_outs = before_outs + self.postnet(before_outs) # (B, odim, Lmax)
else:
after_outs = before_outs
before_outs = before_outs.transpose(2, 1) # (B, Lmax, odim)
after_outs = after_outs.transpose(2, 1) # (B, Lmax, odim)
logits = logits
# apply activation function for scaling
if self.output_activation_fn is not None:
before_outs = self.output_activation_fn(before_outs)
after_outs = self.output_activation_fn(after_outs)
return after_outs, before_outs, logits, att_ws
def inference(self, h, threshold=0.5, minlenratio=0.0, maxlenratio=10.0,
use_att_constraint=False, backward_window=None, forward_window=None):
"""Generate the sequence of features given the sequences of characters.
Args:
h (Tensor): Input sequence of encoder hidden states (T, C).
threshold (float, optional): Threshold to stop generation.
minlenratio (float, optional): Minimum length ratio. If set to 1.0 and the length of input is 10,
the minimum length of outputs will be 10 * 1 = 10.
maxlenratio (float, optional): Maximum length ratio. If set to 10 and the length of input is 10,
the maximum length of outputs will be 10 * 10 = 100.
use_att_constraint (bool): Whether to apply attention constraint introduced in `Deep Voice 3`_.
backward_window (int): Backward window size in attention constraint.
forward_window (int): Forward window size in attention constraint.
Returns:
Tensor: Output sequence of features (L, odim).
Tensor: Output sequence of stop probabilities (L,).
Tensor: Attention weights (L, T).
Note:
This computation is performed in auto-regressive manner.
.. _`Deep Voice 3`: https://arxiv.org/abs/1710.07654
"""
# setup
assert len(h.size()) == 2
hs = h.unsqueeze(0)
ilens = [h.size(0)]
maxlen = int(h.size(0) * maxlenratio)
minlen = int(h.size(0) * minlenratio)
# initialize hidden states of decoder
c_list = [self._zero_state(hs)]
z_list = [self._zero_state(hs)]
for _ in six.moves.range(1, len(self.lstm)):
c_list += [self._zero_state(hs)]
z_list += [self._zero_state(hs)]
prev_out = hs.new_zeros(1, self.odim)
# initialize attention
prev_att_w = None
self.att.reset()
# setup for attention constraint
if use_att_constraint:
last_attended_idx = 0
else:
last_attended_idx = None
# loop for an output sequence
idx = 0
outs, att_ws, probs = [], [], []
while True:
# updated index
idx += self.reduction_factor
# decoder calculation
if self.use_att_extra_inputs:
att_c, att_w = self.att(hs, ilens, z_list[0], prev_att_w, prev_out,
last_attended_idx=last_attended_idx,
backward_window=backward_window,
forward_window=forward_window)
else:
att_c, att_w = self.att(hs, ilens, z_list[0], prev_att_w,
last_attended_idx=last_attended_idx,
backward_window=backward_window,
forward_window=forward_window)
att_ws += [att_w]
prenet_out = self.prenet(prev_out) if self.prenet is not None else prev_out
xs = torch.cat([att_c, prenet_out], dim=1)
z_list[0], c_list[0] = self.lstm[0](xs, (z_list[0], c_list[0]))
for l in six.moves.range(1, len(self.lstm)):
z_list[l], c_list[l] = self.lstm[l](
z_list[l - 1], (z_list[l], c_list[l]))
zcs = torch.cat([z_list[-1], att_c], dim=1) if self.use_concate else z_list[-1]
outs += [self.feat_out(zcs).view(1, self.odim, -1)] # [(1, odim, r), ...]
probs += [torch.sigmoid(self.prob_out(zcs))[0]] # [(r), ...]
if self.output_activation_fn is not None:
prev_out = self.output_activation_fn(outs[-1][:, :, -1]) # (1, odim)
else:
prev_out = outs[-1][:, :, -1] # (1, odim)
if self.cumulate_att_w and prev_att_w is not None:
prev_att_w = prev_att_w + att_w # Note: error when use +=
else:
prev_att_w = att_w
if use_att_constraint:
last_attended_idx = int(att_w.argmax())
# check whether to finish generation
if int(sum(probs[-1] >= threshold)) > 0 or idx >= maxlen:
# check mininum length
if idx < minlen:
continue
outs = torch.cat(outs, dim=2) # (1, odim, L)
if self.postnet is not None:
outs = outs + self.postnet(outs) # (1, odim, L)
outs = outs.transpose(2, 1).squeeze(0) # (L, odim)
probs = torch.cat(probs, dim=0)
att_ws = torch.cat(att_ws, dim=0)
break
if self.output_activation_fn is not None:
outs = self.output_activation_fn(outs)
return outs, probs, att_ws
def calculate_all_attentions(self, hs, hlens, ys):
"""Calculate all of the attention weights.
Args:
hs (Tensor): Batch of the sequences of padded hidden states (B, Tmax, idim).
hlens (LongTensor): Batch of lengths of each input batch (B,).
ys (Tensor): Batch of the sequences of padded target features (B, Lmax, odim).
Returns:
numpy.ndarray: Batch of attention weights (B, Lmax, Tmax).
Note:
This computation is performed in teacher-forcing manner.
"""
# thin out frames (B, Lmax, odim) -> (B, Lmax/r, odim)
if self.reduction_factor > 1:
ys = ys[:, self.reduction_factor - 1::self.reduction_factor]
# length list should be list of int
hlens = list(map(int, hlens))
# initialize hidden states of decoder
c_list = [self._zero_state(hs)]
z_list = [self._zero_state(hs)]
for _ in six.moves.range(1, len(self.lstm)):
c_list += [self._zero_state(hs)]
z_list += [self._zero_state(hs)]
prev_out = hs.new_zeros(hs.size(0), self.odim)
# initialize attention
prev_att_w = None
self.att.reset()
# loop for an output sequence
att_ws = []
for y in ys.transpose(0, 1):
if self.use_att_extra_inputs:
att_c, att_w = self.att(hs, hlens, z_list[0], prev_att_w, prev_out)
else:
att_c, att_w = self.att(hs, hlens, z_list[0], prev_att_w)
att_ws += [att_w]
prenet_out = self.prenet(prev_out) if self.prenet is not None else prev_out
xs = torch.cat([att_c, prenet_out], dim=1)
z_list[0], c_list[0] = self.lstm[0](xs, (z_list[0], c_list[0]))
for l in six.moves.range(1, len(self.lstm)):
z_list[l], c_list[l] = self.lstm[l](
z_list[l - 1], (z_list[l], c_list[l]))
prev_out = y # teacher forcing
if self.cumulate_att_w and prev_att_w is not None:
prev_att_w = prev_att_w + att_w # Note: error when use +=
else:
prev_att_w = att_w
att_ws = torch.stack(att_ws, dim=1) # (B, Lmax, Tmax)
return att_ws
__init__(self, idim, odim, att, dlayers=2, dunits=1024, prenet_layers=2, prenet_units=256, postnet_layers=5, postnet_chans=512, postnet_filts=5, output_activation_fn=None, cumulate_att_w=True, use_batch_norm=True, use_concate=True, dropout_rate=0.5, zoneout_rate=0.1, reduction_factor=1)
special
¶Initialize Tacotron2 decoder module.
Parameters:

Name | Type | Description | Default
---|---|---|---
idim | int | Dimension of the inputs. | required
odim | int | Dimension of the outputs. | required
att | torch.nn.Module | Instance of attention class. | required
dlayers | int | The number of decoder lstm layers. | 2
dunits | int | The number of decoder lstm units. | 1024
prenet_layers | int | The number of prenet layers. | 2
prenet_units | int | The number of prenet units. | 256
postnet_layers | int | The number of postnet layers. | 5
postnet_filts | int | The postnet filter size (kernel size). | 5
postnet_chans | int | The number of postnet filter channels. | 512
output_activation_fn | torch.nn.Module | Activation function for outputs. | None
cumulate_att_w | bool | Whether to cumulate previous attention weight. | True
use_batch_norm | bool | Whether to use batch normalization. | True
use_concate | bool | Whether to concatenate encoder embedding with decoder lstm outputs. | True
dropout_rate | float | Dropout rate. | 0.5
zoneout_rate | float | Zoneout rate. | 0.1
reduction_factor | int | Reduction factor. | 1
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/tacotron2/decoder.py
def __init__(self, idim, odim, att,
dlayers=2,
dunits=1024,
prenet_layers=2,
prenet_units=256,
postnet_layers=5,
postnet_chans=512,
postnet_filts=5,
output_activation_fn=None,
cumulate_att_w=True,
use_batch_norm=True,
use_concate=True,
dropout_rate=0.5,
zoneout_rate=0.1,
reduction_factor=1):
"""Initialize Tacotron2 decoder module.
Args:
idim (int): Dimension of the inputs.
odim (int): Dimension of the outputs.
att (torch.nn.Module): Instance of attention class.
dlayers (int, optional): The number of decoder lstm layers.
dunits (int, optional): The number of decoder lstm units.
prenet_layers (int, optional): The number of prenet layers.
prenet_units (int, optional): The number of prenet units.
postnet_layers (int, optional): The number of postnet layers.
postnet_filts (int, optional): The number of postnet filter size.
postnet_chans (int, optional): The number of postnet filter channels.
output_activation_fn (torch.nn.Module, optional): Activation function for outputs.
cumulate_att_w (bool, optional): Whether to cumulate previous attention weight.
use_batch_norm (bool, optional): Whether to use batch normalization.
use_concate (bool, optional): Whether to concatenate encoder embedding with decoder lstm outputs.
dropout_rate (float, optional): Dropout rate.
zoneout_rate (float, optional): Zoneout rate.
reduction_factor (int, optional): Reduction factor.
"""
super(Decoder, self).__init__()
# store the hyperparameters
self.idim = idim
self.odim = odim
self.att = att
self.output_activation_fn = output_activation_fn
self.cumulate_att_w = cumulate_att_w
self.use_concate = use_concate
self.reduction_factor = reduction_factor
# check attention type
if isinstance(self.att, AttForwardTA):
self.use_att_extra_inputs = True
else:
self.use_att_extra_inputs = False
# define lstm network
prenet_units = prenet_units if prenet_layers != 0 else odim
self.lstm = torch.nn.ModuleList()
for layer in six.moves.range(dlayers):
iunits = idim + prenet_units if layer == 0 else dunits
lstm = torch.nn.LSTMCell(iunits, dunits)
if zoneout_rate > 0.0:
lstm = ZoneOutCell(lstm, zoneout_rate)
self.lstm += [lstm]
# define prenet
if prenet_layers > 0:
self.prenet = Prenet(
idim=odim,
n_layers=prenet_layers,
n_units=prenet_units,
dropout_rate=dropout_rate)
else:
self.prenet = None
# define postnet
if postnet_layers > 0:
self.postnet = Postnet(
idim=idim,
odim=odim,
n_layers=postnet_layers,
n_chans=postnet_chans,
n_filts=postnet_filts,
use_batch_norm=use_batch_norm,
dropout_rate=dropout_rate)
else:
self.postnet = None
# define projection layers
iunits = idim + dunits if use_concate else dunits
self.feat_out = torch.nn.Linear(iunits, odim * reduction_factor, bias=False)
self.prob_out = torch.nn.Linear(iunits, reduction_factor)
# initialize
self.apply(decoder_init)
calculate_all_attentions(self, hs, hlens, ys)
¶Calculate all of the attention weights.
Parameters:

Name | Type | Description | Default
---|---|---|---
hs | Tensor | Batch of the sequences of padded hidden states (B, Tmax, idim). | required
hlens | LongTensor | Batch of lengths of each input batch (B,). | required
ys | Tensor | Batch of the sequences of padded target features (B, Lmax, odim). | required

Returns:

Type | Description
---|---
numpy.ndarray | Batch of attention weights (B, Lmax, Tmax).
Note
This computation is performed in teacher-forcing manner.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/tacotron2/decoder.py
def calculate_all_attentions(self, hs, hlens, ys):
"""Calculate all of the attention weights.
Args:
hs (Tensor): Batch of the sequences of padded hidden states (B, Tmax, idim).
hlens (LongTensor): Batch of lengths of each input batch (B,).
ys (Tensor): Batch of the sequences of padded target features (B, Lmax, odim).
Returns:
numpy.ndarray: Batch of attention weights (B, Lmax, Tmax).
Note:
This computation is performed in teacher-forcing manner.
"""
# thin out frames (B, Lmax, odim) -> (B, Lmax/r, odim)
if self.reduction_factor > 1:
ys = ys[:, self.reduction_factor - 1::self.reduction_factor]
# length list should be list of int
hlens = list(map(int, hlens))
# initialize hidden states of decoder
c_list = [self._zero_state(hs)]
z_list = [self._zero_state(hs)]
for _ in six.moves.range(1, len(self.lstm)):
c_list += [self._zero_state(hs)]
z_list += [self._zero_state(hs)]
prev_out = hs.new_zeros(hs.size(0), self.odim)
# initialize attention
prev_att_w = None
self.att.reset()
# loop for an output sequence
att_ws = []
for y in ys.transpose(0, 1):
if self.use_att_extra_inputs:
att_c, att_w = self.att(hs, hlens, z_list[0], prev_att_w, prev_out)
else:
att_c, att_w = self.att(hs, hlens, z_list[0], prev_att_w)
att_ws += [att_w]
prenet_out = self.prenet(prev_out) if self.prenet is not None else prev_out
xs = torch.cat([att_c, prenet_out], dim=1)
z_list[0], c_list[0] = self.lstm[0](xs, (z_list[0], c_list[0]))
for l in six.moves.range(1, len(self.lstm)):
z_list[l], c_list[l] = self.lstm[l](
z_list[l - 1], (z_list[l], c_list[l]))
prev_out = y # teacher forcing
if self.cumulate_att_w and prev_att_w is not None:
prev_att_w = prev_att_w + att_w # Note: error when use +=
else:
prev_att_w = att_w
att_ws = torch.stack(att_ws, dim=1) # (B, Lmax, Tmax)
return att_ws
forward(self, hs, hlens, ys)
¶Calculate forward propagation.
Parameters:

Name | Type | Description | Default
---|---|---|---
hs | Tensor | Batch of the sequences of padded hidden states (B, Tmax, idim). | required
hlens | LongTensor | Batch of lengths of each input batch (B,). | required
ys | Tensor | Batch of the sequences of padded target features (B, Lmax, odim). | required

Returns:

Type | Description
---|---
Tensor | Batch of output tensors after postnet (B, Lmax, odim).
Tensor | Batch of output tensors before postnet (B, Lmax, odim).
Tensor | Batch of logits of stop prediction (B, Lmax).
Tensor | Batch of attention weights (B, Lmax, Tmax).
Note
This computation is performed in teacher-forcing manner.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/tacotron2/decoder.py
def forward(self, hs, hlens, ys):
"""Calculate forward propagation.
Args:
hs (Tensor): Batch of the sequences of padded hidden states (B, Tmax, idim).
hlens (LongTensor): Batch of lengths of each input batch (B,).
ys (Tensor): Batch of the sequences of padded target features (B, Lmax, odim).
Returns:
Tensor: Batch of output tensors after postnet (B, Lmax, odim).
Tensor: Batch of output tensors before postnet (B, Lmax, odim).
Tensor: Batch of logits of stop prediction (B, Lmax).
Tensor: Batch of attention weights (B, Lmax, Tmax).
Note:
This computation is performed in teacher-forcing manner.
"""
# thin out frames (B, Lmax, odim) -> (B, Lmax/r, odim)
if self.reduction_factor > 1:
ys = ys[:, self.reduction_factor - 1::self.reduction_factor]
# length list should be list of int
hlens = list(map(int, hlens))
# initialize hidden states of decoder
c_list = [self._zero_state(hs)]
z_list = [self._zero_state(hs)]
for _ in six.moves.range(1, len(self.lstm)):
c_list += [self._zero_state(hs)]
z_list += [self._zero_state(hs)]
prev_out = hs.new_zeros(hs.size(0), self.odim)
# initialize attention
prev_att_w = None
self.att.reset()
# loop for an output sequence
outs, logits, att_ws = [], [], []
for y in ys.transpose(0, 1):
if self.use_att_extra_inputs:
att_c, att_w = self.att(hs, hlens, z_list[0], prev_att_w, prev_out)
else:
att_c, att_w = self.att(hs, hlens, z_list[0], prev_att_w)
prenet_out = self.prenet(prev_out) if self.prenet is not None else prev_out
xs = torch.cat([att_c, prenet_out], dim=1)
z_list[0], c_list[0] = self.lstm[0](xs, (z_list[0], c_list[0]))
for l in six.moves.range(1, len(self.lstm)):
z_list[l], c_list[l] = self.lstm[l](
z_list[l - 1], (z_list[l], c_list[l]))
zcs = torch.cat([z_list[-1], att_c], dim=1) if self.use_concate else z_list[-1]
outs += [self.feat_out(zcs).view(hs.size(0), self.odim, -1)]
logits += [self.prob_out(zcs)]
att_ws += [att_w]
prev_out = y # teacher forcing
if self.cumulate_att_w and prev_att_w is not None:
prev_att_w = prev_att_w + att_w # Note: error when use +=
else:
prev_att_w = att_w
logits = torch.cat(logits, dim=1) # (B, Lmax)
before_outs = torch.cat(outs, dim=2) # (B, odim, Lmax)
att_ws = torch.stack(att_ws, dim=1) # (B, Lmax, Tmax)
if self.reduction_factor > 1:
before_outs = before_outs.view(before_outs.size(0), self.odim, -1) # (B, odim, Lmax)
if self.postnet is not None:
after_outs = before_outs + self.postnet(before_outs) # (B, odim, Lmax)
else:
after_outs = before_outs
before_outs = before_outs.transpose(2, 1) # (B, Lmax, odim)
after_outs = after_outs.transpose(2, 1) # (B, Lmax, odim)
logits = logits
# apply activation function for scaling
if self.output_activation_fn is not None:
before_outs = self.output_activation_fn(before_outs)
after_outs = self.output_activation_fn(after_outs)
return after_outs, before_outs, logits, att_ws
inference(self, h, threshold=0.5, minlenratio=0.0, maxlenratio=10.0, use_att_constraint=False, backward_window=None, forward_window=None)
¶Generate the sequence of features given the sequences of characters.
Parameters:

Name | Type | Description | Default
---|---|---|---
h | Tensor | Input sequence of encoder hidden states (T, C). | required
threshold | float | Threshold to stop generation. | 0.5
minlenratio | float | Minimum length ratio. If set to 1.0 and the length of input is 10, the minimum length of outputs will be 10 * 1 = 10. | 0.0
maxlenratio | float | Maximum length ratio. If set to 10 and the length of input is 10, the maximum length of outputs will be 10 * 10 = 100. | 10.0
use_att_constraint | bool | Whether to apply the attention constraint introduced in Deep Voice 3 (https://arxiv.org/abs/1710.07654). | False
backward_window | int | Backward window size in attention constraint. | None
forward_window | int | Forward window size in attention constraint. | None

Returns:

Type | Description
---|---
Tensor | Output sequence of features (L, odim).
Tensor | Output sequence of stop probabilities (L,).
Tensor | Attention weights (L, T).
Note
This computation is performed in auto-regressive manner.
Deep Voice 3: https://arxiv.org/abs/1710.07654
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/tacotron2/decoder.py
def inference(self, h, threshold=0.5, minlenratio=0.0, maxlenratio=10.0,
use_att_constraint=False, backward_window=None, forward_window=None):
"""Generate the sequence of features given the sequences of characters.
Args:
h (Tensor): Input sequence of encoder hidden states (T, C).
threshold (float, optional): Threshold to stop generation.
minlenratio (float, optional): Minimum length ratio. If set to 1.0 and the length of input is 10,
the minimum length of outputs will be 10 * 1 = 10.
maxlenratio (float, optional): Maximum length ratio. If set to 10 and the length of input is 10,
the maximum length of outputs will be 10 * 10 = 100.
use_att_constraint (bool): Whether to apply attention constraint introduced in `Deep Voice 3`_.
backward_window (int): Backward window size in attention constraint.
forward_window (int): Forward window size in attention constraint.
Returns:
Tensor: Output sequence of features (L, odim).
Tensor: Output sequence of stop probabilities (L,).
Tensor: Attention weights (L, T).
Note:
This computation is performed in auto-regressive manner.
.. _`Deep Voice 3`: https://arxiv.org/abs/1710.07654
"""
# setup
assert len(h.size()) == 2
hs = h.unsqueeze(0)
ilens = [h.size(0)]
maxlen = int(h.size(0) * maxlenratio)
minlen = int(h.size(0) * minlenratio)
# initialize hidden states of decoder
c_list = [self._zero_state(hs)]
z_list = [self._zero_state(hs)]
for _ in six.moves.range(1, len(self.lstm)):
c_list += [self._zero_state(hs)]
z_list += [self._zero_state(hs)]
prev_out = hs.new_zeros(1, self.odim)
# initialize attention
prev_att_w = None
self.att.reset()
# setup for attention constraint
if use_att_constraint:
last_attended_idx = 0
else:
last_attended_idx = None
# loop for an output sequence
idx = 0
outs, att_ws, probs = [], [], []
while True:
# updated index
idx += self.reduction_factor
# decoder calculation
if self.use_att_extra_inputs:
att_c, att_w = self.att(hs, ilens, z_list[0], prev_att_w, prev_out,
last_attended_idx=last_attended_idx,
backward_window=backward_window,
forward_window=forward_window)
else:
att_c, att_w = self.att(hs, ilens, z_list[0], prev_att_w,
last_attended_idx=last_attended_idx,
backward_window=backward_window,
forward_window=forward_window)
att_ws += [att_w]
prenet_out = self.prenet(prev_out) if self.prenet is not None else prev_out
xs = torch.cat([att_c, prenet_out], dim=1)
z_list[0], c_list[0] = self.lstm[0](xs, (z_list[0], c_list[0]))
for l in six.moves.range(1, len(self.lstm)):
z_list[l], c_list[l] = self.lstm[l](
z_list[l - 1], (z_list[l], c_list[l]))
zcs = torch.cat([z_list[-1], att_c], dim=1) if self.use_concate else z_list[-1]
outs += [self.feat_out(zcs).view(1, self.odim, -1)] # [(1, odim, r), ...]
probs += [torch.sigmoid(self.prob_out(zcs))[0]] # [(r), ...]
if self.output_activation_fn is not None:
prev_out = self.output_activation_fn(outs[-1][:, :, -1]) # (1, odim)
else:
prev_out = outs[-1][:, :, -1] # (1, odim)
if self.cumulate_att_w and prev_att_w is not None:
prev_att_w = prev_att_w + att_w # Note: error when use +=
else:
prev_att_w = att_w
if use_att_constraint:
last_attended_idx = int(att_w.argmax())
# check whether to finish generation
if int(sum(probs[-1] >= threshold)) > 0 or idx >= maxlen:
# check mininum length
if idx < minlen:
continue
outs = torch.cat(outs, dim=2) # (1, odim, L)
if self.postnet is not None:
outs = outs + self.postnet(outs) # (1, odim, L)
outs = outs.transpose(2, 1).squeeze(0) # (L, odim)
probs = torch.cat(probs, dim=0)
att_ws = torch.cat(att_ws, dim=0)
break
if self.output_activation_fn is not None:
outs = self.output_activation_fn(outs)
return outs, probs, att_ws
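The generation loop above stops once any stop probability in the current reduction-factor chunk reaches threshold, but never before minlen steps and never after maxlen steps. A self-contained sketch of just that stopping rule (should_stop and the toy probabilities are illustrative, not part of the library):
import torch

def should_stop(probs, idx, threshold=0.5, minlen=0, maxlen=100):
    # Mirrors the check in Decoder.inference: stop on a confident stop token
    # or on reaching maxlen, but only once minlen frames have been produced.
    hit = bool((probs >= threshold).any()) or idx >= maxlen
    return hit and idx >= minlen

for idx in range(1, 20):
    probs = torch.sigmoid(torch.tensor([idx - 10.0]))  # toy stop probabilities
    if should_stop(probs, idx, minlen=5, maxlen=15):
        print("stopped at step", idx)
        break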
Postnet (Module)
¶Postnet module for Spectrogram prediction network.
This is the Postnet module of the Spectrogram prediction network, which is described in `Natural TTS Synthesis by
Conditioning WaveNet on Mel Spectrogram Predictions`_. The Postnet refines the Mel-filterbank prediction
of the decoder, which helps to recover the fine structure of the spectrogram.
.. _Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions: https://arxiv.org/abs/1712.05884
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/tacotron2/decoder.py
class Postnet(torch.nn.Module):
"""Postnet module for Spectrogram prediction network.
This is a module of Postnet in Spectrogram prediction network, which is described in `Natural TTS Synthesis by
Conditioning WaveNet on Mel Spectrogram Predictions`_. The Postnet refines the predicted
Mel-filterbank of the decoder, which helps to recover the fine structure of the spectrogram.
.. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
https://arxiv.org/abs/1712.05884
"""
def __init__(self, idim, odim, n_layers=5, n_chans=512, n_filts=5, dropout_rate=0.5, use_batch_norm=True):
"""Initialize postnet module.
Args:
idim (int): Dimension of the inputs.
odim (int): Dimension of the outputs.
n_layers (int, optional): The number of layers.
n_filts (int, optional): The filter size.
n_chans (int, optional): The number of filter channels.
use_batch_norm (bool, optional): Whether to use batch normalization.
dropout_rate (float, optional): Dropout rate.
"""
super(Postnet, self).__init__()
self.postnet = torch.nn.ModuleList()
for layer in six.moves.range(n_layers - 1):
ichans = odim if layer == 0 else n_chans
ochans = odim if layer == n_layers - 1 else n_chans
if use_batch_norm:
self.postnet += [torch.nn.Sequential(
torch.nn.Conv1d(ichans, ochans, n_filts, stride=1,
padding=(n_filts - 1) // 2, bias=False),
torch.nn.BatchNorm1d(ochans),
torch.nn.Tanh(),
torch.nn.Dropout(dropout_rate))]
else:
self.postnet += [torch.nn.Sequential(
torch.nn.Conv1d(ichans, ochans, n_filts, stride=1,
padding=(n_filts - 1) // 2, bias=False),
torch.nn.Tanh(),
torch.nn.Dropout(dropout_rate))]
ichans = n_chans if n_layers != 1 else odim
if use_batch_norm:
self.postnet += [torch.nn.Sequential(
torch.nn.Conv1d(ichans, odim, n_filts, stride=1,
padding=(n_filts - 1) // 2, bias=False),
torch.nn.BatchNorm1d(odim),
torch.nn.Dropout(dropout_rate))]
else:
self.postnet += [torch.nn.Sequential(
torch.nn.Conv1d(ichans, odim, n_filts, stride=1,
padding=(n_filts - 1) // 2, bias=False),
torch.nn.Dropout(dropout_rate))]
def forward(self, xs):
"""Calculate forward propagation.
Args:
xs (Tensor): Batch of the sequences of padded input tensors (B, idim, Tmax).
Returns:
Tensor: Batch of padded output tensor. (B, odim, Tmax).
"""
for l in six.moves.range(len(self.postnet)):
xs = self.postnet[l](xs)
return xs
__init__(self, idim, odim, n_layers=5, n_chans=512, n_filts=5, dropout_rate=0.5, use_batch_norm=True)
special
¶Initialize postnet module.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
idim |
int |
Dimension of the inputs. |
required |
odim |
int |
Dimension of the outputs. |
required |
n_layers |
int |
The number of layers. |
5 |
n_filts |
int |
The filter size. |
5 |
n_chans |
int |
The number of filter channels. |
512 |
use_batch_norm |
bool |
Whether to use batch normalization. |
True |
dropout_rate |
float |
Dropout rate. |
0.5 |
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/tacotron2/decoder.py
def __init__(self, idim, odim, n_layers=5, n_chans=512, n_filts=5, dropout_rate=0.5, use_batch_norm=True):
"""Initialize postnet module.
Args:
idim (int): Dimension of the inputs.
odim (int): Dimension of the outputs.
n_layers (int, optional): The number of layers.
n_filts (int, optional): The filter size.
n_chans (int, optional): The number of filter channels.
use_batch_norm (bool, optional): Whether to use batch normalization.
dropout_rate (float, optional): Dropout rate.
"""
super(Postnet, self).__init__()
self.postnet = torch.nn.ModuleList()
for layer in six.moves.range(n_layers - 1):
ichans = odim if layer == 0 else n_chans
ochans = odim if layer == n_layers - 1 else n_chans
if use_batch_norm:
self.postnet += [torch.nn.Sequential(
torch.nn.Conv1d(ichans, ochans, n_filts, stride=1,
padding=(n_filts - 1) // 2, bias=False),
torch.nn.BatchNorm1d(ochans),
torch.nn.Tanh(),
torch.nn.Dropout(dropout_rate))]
else:
self.postnet += [torch.nn.Sequential(
torch.nn.Conv1d(ichans, ochans, n_filts, stride=1,
padding=(n_filts - 1) // 2, bias=False),
torch.nn.Tanh(),
torch.nn.Dropout(dropout_rate))]
ichans = n_chans if n_layers != 1 else odim
if use_batch_norm:
self.postnet += [torch.nn.Sequential(
torch.nn.Conv1d(ichans, odim, n_filts, stride=1,
padding=(n_filts - 1) // 2, bias=False),
torch.nn.BatchNorm1d(odim),
torch.nn.Dropout(dropout_rate))]
else:
self.postnet += [torch.nn.Sequential(
torch.nn.Conv1d(ichans, odim, n_filts, stride=1,
padding=(n_filts - 1) // 2, bias=False),
torch.nn.Dropout(dropout_rate))]
forward(self, xs)
¶Calculate forward propagation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
xs |
Tensor |
Batch of the sequences of padded input tensors (B, idim, Tmax). |
required |
Returns:
Type | Description |
---|---|
Tensor |
Batch of padded output tensor. (B, odim, Tmax). |
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/tacotron2/decoder.py
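Usage sketch (not part of the generated source): the Postnet is applied as a residual refinement on top of the decoder's mel prediction, exactly as in after_outs = before_outs + self.postnet(before_outs) above. The import path is assumed from the source location shown above; idim is accepted but unused by the layers.
import torch
from tools.espnet_minimal.nets.pytorch_backend.tacotron2.decoder import Postnet

B, odim, Lmax = 2, 80, 120
postnet = Postnet(idim=odim, odim=odim).eval()
before_outs = torch.randn(B, odim, Lmax)             # decoder mel prediction (B, odim, Lmax)
with torch.no_grad():
    after_outs = before_outs + postnet(before_outs)  # residual refinement
print(after_outs.shape)                              # torch.Size([2, 80, 120])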
Prenet (Module)
¶Prenet module for decoder of Spectrogram prediction network.
This is the Prenet module in the decoder of the Spectrogram prediction network, which is described in `Natural TTS
Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`_. The Prenet performs a nonlinear conversion
of the inputs before they are fed to the auto-regressive LSTM, which helps the model learn diagonal attention.
Note
This module always applies dropout, even during evaluation. See the details in `Natural TTS Synthesis by
Conditioning WaveNet on Mel Spectrogram Predictions`_.
.. _Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions: https://arxiv.org/abs/1712.05884
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/tacotron2/decoder.py
class Prenet(torch.nn.Module):
"""Prenet module for decoder of Spectrogram prediction network.
This is a module of Prenet in the decoder of Spectrogram prediction network, which is described in `Natural TTS
Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`_. The Prenet performs a nonlinear conversion
of the inputs before they are fed to the auto-regressive LSTM, which helps to learn diagonal attention.
Note:
This module always applies dropout even in evaluation. See the detail in `Natural TTS Synthesis by
Conditioning WaveNet on Mel Spectrogram Predictions`_.
.. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
https://arxiv.org/abs/1712.05884
"""
def __init__(self, idim, n_layers=2, n_units=256, dropout_rate=0.5):
"""Initialize prenet module.
Args:
idim (int): Dimension of the inputs.
n_layers (int, optional): The number of prenet layers.
n_units (int, optional): The number of prenet units.
dropout_rate (float, optional): Dropout rate.
"""
super(Prenet, self).__init__()
self.dropout_rate = dropout_rate
self.prenet = torch.nn.ModuleList()
for layer in six.moves.range(n_layers):
n_inputs = idim if layer == 0 else n_units
self.prenet += [torch.nn.Sequential(
torch.nn.Linear(n_inputs, n_units),
torch.nn.ReLU())]
def forward(self, x):
"""Calculate forward propagation.
Args:
x (Tensor): Batch of input tensors (B, ..., idim).
Returns:
Tensor: Batch of output tensors (B, ..., odim).
"""
for l in six.moves.range(len(self.prenet)):
x = F.dropout(self.prenet[l](x), self.dropout_rate)
return x
__init__(self, idim, n_layers=2, n_units=256, dropout_rate=0.5)
special
¶Initialize prenet module.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
idim |
int |
Dimension of the inputs. |
required |
n_layers |
int |
The number of prenet layers. |
2 |
n_units |
int |
The number of prenet units. |
256 |
dropout_rate |
float |
Dropout rate. |
0.5 |
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/tacotron2/decoder.py
def __init__(self, idim, n_layers=2, n_units=256, dropout_rate=0.5):
"""Initialize prenet module.
Args:
idim (int): Dimension of the inputs.
n_layers (int, optional): The number of prenet layers.
n_units (int, optional): The number of prenet units.
dropout_rate (float, optional): Dropout rate.
"""
super(Prenet, self).__init__()
self.dropout_rate = dropout_rate
self.prenet = torch.nn.ModuleList()
for layer in six.moves.range(n_layers):
n_inputs = idim if layer == 0 else n_units
self.prenet += [torch.nn.Sequential(
torch.nn.Linear(n_inputs, n_units),
torch.nn.ReLU())]
forward(self, x)
¶Calculate forward propagation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x |
Tensor |
Batch of input tensors (B, ..., idim). |
required |
Returns:
Type | Description |
---|---|
Tensor |
Batch of output tensors (B, ..., odim). |
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/tacotron2/decoder.py
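Quick check of the Note above (import path assumed from the source location shown): because forward calls F.dropout without passing the module's training flag, dropout stays active even in eval mode, so repeated calls on the same input differ.
import torch
from tools.espnet_minimal.nets.pytorch_backend.tacotron2.decoder import Prenet

torch.manual_seed(0)
prenet = Prenet(idim=80, n_layers=2, n_units=256, dropout_rate=0.5).eval()
x = torch.randn(4, 80)
with torch.no_grad():
    y1, y2 = prenet(x), prenet(x)
print(torch.allclose(y1, y2))  # expected: False, dropout is still applied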
ZoneOutCell (Module)
¶ZoneOut Cell module.
This is a module of zoneout described in `Zoneout: Regularizing RNNs by Randomly Preserving Hidden Activations`_.
This code is modified from `eladhoffer/seq2seq.pytorch`_.
Examples:
>>> lstm = torch.nn.LSTMCell(16, 32)
>>> lstm = ZoneOutCell(lstm, 0.5)
.. _`Zoneout: Regularizing RNNs by Randomly Preserving Hidden Activations`: https://arxiv.org/abs/1606.01305
.. _`eladhoffer/seq2seq.pytorch`: https://github.com/eladhoffer/seq2seq.pytorch
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/tacotron2/decoder.py
class ZoneOutCell(torch.nn.Module):
"""ZoneOut Cell module.
This is a module of zoneout described in `Zoneout: Regularizing RNNs by Randomly Preserving Hidden Activations`_.
This code is modified from `eladhoffer/seq2seq.pytorch`_.
Examples:
>>> lstm = torch.nn.LSTMCell(16, 32)
>>> lstm = ZoneOutCell(lstm, 0.5)
.. _`Zoneout: Regularizing RNNs by Randomly Preserving Hidden Activations`:
https://arxiv.org/abs/1606.01305
.. _`eladhoffer/seq2seq.pytorch`:
https://github.com/eladhoffer/seq2seq.pytorch
"""
def __init__(self, cell, zoneout_rate=0.1):
"""Initialize zone out cell module.
Args:
cell (torch.nn.Module): Pytorch recurrent cell module e.g. `torch.nn.LSTMCell`.
zoneout_rate (float, optional): Probability of zoneout from 0.0 to 1.0.
"""
super(ZoneOutCell, self).__init__()
self.cell = cell
self.hidden_size = cell.hidden_size
self.zoneout_rate = zoneout_rate
if zoneout_rate > 1.0 or zoneout_rate < 0.0:
raise ValueError("zoneout probability must be in the range from 0.0 to 1.0.")
def forward(self, inputs, hidden):
"""Calculate forward propagation.
Args:
inputs (Tensor): Batch of input tensor (B, input_size).
hidden (tuple):
- Tensor: Batch of initial hidden states (B, hidden_size).
- Tensor: Batch of initial cell states (B, hidden_size).
Returns:
tuple:
- Tensor: Batch of next hidden states (B, hidden_size).
- Tensor: Batch of next cell states (B, hidden_size).
"""
next_hidden = self.cell(inputs, hidden)
next_hidden = self._zoneout(hidden, next_hidden, self.zoneout_rate)
return next_hidden
def _zoneout(self, h, next_h, prob):
# apply recursively
if isinstance(h, tuple):
num_h = len(h)
if not isinstance(prob, tuple):
prob = tuple([prob] * num_h)
return tuple([self._zoneout(h[i], next_h[i], prob[i]) for i in range(num_h)])
if self.training:
mask = h.new(*h.size()).bernoulli_(prob)
return mask * h + (1 - mask) * next_h
else:
return prob * h + (1 - prob) * next_h
__init__(self, cell, zoneout_rate=0.1)
special
¶Initialize zone out cell module.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
cell |
torch.nn.Module |
PyTorch recurrent cell module, e.g. torch.nn.LSTMCell. |
required |
zoneout_rate |
float |
Probability of zoneout from 0.0 to 1.0. |
0.1 |
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/tacotron2/decoder.py
def __init__(self, cell, zoneout_rate=0.1):
"""Initialize zone out cell module.
Args:
cell (torch.nn.Module): Pytorch recurrent cell module e.g. `torch.nn.LSTMCell`.
zoneout_rate (float, optional): Probability of zoneout from 0.0 to 1.0.
"""
super(ZoneOutCell, self).__init__()
self.cell = cell
self.hidden_size = cell.hidden_size
self.zoneout_rate = zoneout_rate
if zoneout_rate > 1.0 or zoneout_rate < 0.0:
raise ValueError("zoneout probability must be in the range from 0.0 to 1.0.")
forward(self, inputs, hidden)
¶Calculate forward propagation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
inputs |
Tensor |
Batch of input tensor (B, input_size). |
required |
hidden |
tuple |
Batch of initial hidden states (B, hidden_size) and initial cell states (B, hidden_size). |
required |
Returns:
Type | Description |
---|---|
tuple |
Batch of next hidden states (B, hidden_size) and next cell states (B, hidden_size). |
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/tacotron2/decoder.py
def forward(self, inputs, hidden):
"""Calculate forward propagation.
Args:
inputs (Tensor): Batch of input tensor (B, input_size).
hidden (tuple):
- Tensor: Batch of initial hidden states (B, hidden_size).
- Tensor: Batch of initial cell states (B, hidden_size).
Returns:
tuple:
- Tensor: Batch of next hidden states (B, hidden_size).
- Tensor: Batch of next cell states (B, hidden_size).
"""
next_hidden = self.cell(inputs, hidden)
next_hidden = self._zoneout(hidden, next_hidden, self.zoneout_rate)
return next_hidden
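Usage sketch (import path assumed from the source location shown above): in training mode each hidden unit keeps its previous value with probability zoneout_rate; in eval mode the cell interpolates deterministically, zoneout_rate * h_old + (1 - zoneout_rate) * h_new, as in _zoneout above.
import torch
from tools.espnet_minimal.nets.pytorch_backend.tacotron2.decoder import ZoneOutCell

torch.manual_seed(0)
cell = ZoneOutCell(torch.nn.LSTMCell(16, 32), zoneout_rate=0.1)
x = torch.randn(4, 16)
h = (torch.zeros(4, 32), torch.zeros(4, 32))
h = cell(x, h)        # training mode: stochastic zoneout mask per unit
cell.eval()
h = cell(x, h)        # eval mode: deterministic interpolation
print(h[0].shape, h[1].shape)  # torch.Size([4, 32]) torch.Size([4, 32])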
decoder_init(m)
¶
encoder
¶
Tacotron2 encoder related modules.
Encoder (Module)
¶Encoder module of Spectrogram prediction network.
This is the encoder module of the Spectrogram prediction network in Tacotron2, which is described in `Natural TTS
Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`_. The encoder converts the
sequence of characters into a sequence of hidden states.
.. _Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions: https://arxiv.org/abs/1712.05884
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/tacotron2/encoder.py
class Encoder(torch.nn.Module):
"""Encoder module of Spectrogram prediction network.
This is a module of encoder of Spectrogram prediction network in Tacotron2, which is described in `Natural TTS
Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`_. This is the encoder which converts the
sequence of characters into the sequence of hidden states.
.. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
https://arxiv.org/abs/1712.05884
"""
def __init__(self, idim,
embed_dim=512,
elayers=1,
eunits=512,
econv_layers=3,
econv_chans=512,
econv_filts=5,
use_batch_norm=True,
use_residual=False,
dropout_rate=0.5,
padding_idx=0):
"""Initialize Tacotron2 encoder module.
Args:
idim (int): Dimension of the inputs.
embed_dim (int, optional): Dimension of character embedding.
elayers (int, optional): The number of encoder blstm layers.
eunits (int, optional): The number of encoder blstm units.
econv_layers (int, optional): The number of encoder conv layers.
econv_filts (int, optional): The filter size of the encoder conv layers.
econv_chans (int, optional): The number of encoder conv filter channels.
use_batch_norm (bool, optional): Whether to use batch normalization.
use_residual (bool, optional): Whether to use residual connection.
dropout_rate (float, optional): Dropout rate.
padding_idx (int, optional): Padding index of the embedding layer.
"""
super(Encoder, self).__init__()
# store the hyperparameters
self.idim = idim
self.use_residual = use_residual
# define network layer modules
self.embed = torch.nn.Embedding(idim, embed_dim, padding_idx=padding_idx)
if econv_layers > 0:
self.convs = torch.nn.ModuleList()
for layer in six.moves.range(econv_layers):
ichans = embed_dim if layer == 0 else econv_chans
if use_batch_norm:
self.convs += [torch.nn.Sequential(
torch.nn.Conv1d(ichans, econv_chans, econv_filts, stride=1,
padding=(econv_filts - 1) // 2, bias=False),
torch.nn.BatchNorm1d(econv_chans),
torch.nn.ReLU(),
torch.nn.Dropout(dropout_rate))]
else:
self.convs += [torch.nn.Sequential(
torch.nn.Conv1d(ichans, econv_chans, econv_filts, stride=1,
padding=(econv_filts - 1) // 2, bias=False),
torch.nn.ReLU(),
torch.nn.Dropout(dropout_rate))]
else:
self.convs = None
if elayers > 0:
iunits = econv_chans if econv_layers != 0 else embed_dim
self.blstm = torch.nn.LSTM(
iunits, eunits // 2, elayers,
batch_first=True,
bidirectional=True)
else:
self.blstm = None
# initialize
self.apply(encoder_init)
def forward(self, xs, ilens=None):
"""Calculate forward propagation.
Args:
xs (Tensor): Batch of the padded sequence of character ids (B, Tmax). Padded value should be 0.
ilens (LongTensor): Batch of lengths of each input batch (B,).
Returns:
Tensor: Batch of the sequences of encoder states(B, Tmax, eunits).
LongTensor: Batch of lengths of each sequence (B,)
"""
xs = self.embed(xs).transpose(1, 2)
if self.convs is not None:
for l in six.moves.range(len(self.convs)):
if self.use_residual:
xs += self.convs[l](xs)
else:
xs = self.convs[l](xs)
if self.blstm is None:
return xs.transpose(1, 2)
xs = pack_padded_sequence(xs.transpose(1, 2), ilens, batch_first=True)
self.blstm.flatten_parameters()
xs, _ = self.blstm(xs) # (B, Tmax, C)
xs, hlens = pad_packed_sequence(xs, batch_first=True)
return xs, hlens
def inference(self, x):
"""Inference.
Args:
x (Tensor): The sequence of character ids (T,).
Returns:
Tensor: The sequence of encoder states (T, eunits).
"""
assert len(x.size()) == 1
xs = x.unsqueeze(0)
ilens = [x.size(0)]
return self.forward(xs, ilens)[0][0]
__init__(self, idim, embed_dim=512, elayers=1, eunits=512, econv_layers=3, econv_chans=512, econv_filts=5, use_batch_norm=True, use_residual=False, dropout_rate=0.5, padding_idx=0)
special
¶Initialize Tacotron2 encoder module.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/tacotron2/encoder.py
def __init__(self, idim,
embed_dim=512,
elayers=1,
eunits=512,
econv_layers=3,
econv_chans=512,
econv_filts=5,
use_batch_norm=True,
use_residual=False,
dropout_rate=0.5,
padding_idx=0):
"""Initialize Tacotron2 encoder module.
Args:
idim (int): Dimension of the inputs.
embed_dim (int, optional): Dimension of character embedding.
elayers (int, optional): The number of encoder blstm layers.
eunits (int, optional): The number of encoder blstm units.
econv_layers (int, optional): The number of encoder conv layers.
econv_filts (int, optional): The filter size of the encoder conv layers.
econv_chans (int, optional): The number of encoder conv filter channels.
use_batch_norm (bool, optional): Whether to use batch normalization.
use_residual (bool, optional): Whether to use residual connection.
dropout_rate (float, optional): Dropout rate.
padding_idx (int, optional): Padding index of the embedding layer.
"""
super(Encoder, self).__init__()
# store the hyperparameters
self.idim = idim
self.use_residual = use_residual
# define network layer modules
self.embed = torch.nn.Embedding(idim, embed_dim, padding_idx=padding_idx)
if econv_layers > 0:
self.convs = torch.nn.ModuleList()
for layer in six.moves.range(econv_layers):
ichans = embed_dim if layer == 0 else econv_chans
if use_batch_norm:
self.convs += [torch.nn.Sequential(
torch.nn.Conv1d(ichans, econv_chans, econv_filts, stride=1,
padding=(econv_filts - 1) // 2, bias=False),
torch.nn.BatchNorm1d(econv_chans),
torch.nn.ReLU(),
torch.nn.Dropout(dropout_rate))]
else:
self.convs += [torch.nn.Sequential(
torch.nn.Conv1d(ichans, econv_chans, econv_filts, stride=1,
padding=(econv_filts - 1) // 2, bias=False),
torch.nn.ReLU(),
torch.nn.Dropout(dropout_rate))]
else:
self.convs = None
if elayers > 0:
iunits = econv_chans if econv_layers != 0 else embed_dim
self.blstm = torch.nn.LSTM(
iunits, eunits // 2, elayers,
batch_first=True,
bidirectional=True)
else:
self.blstm = None
# initialize
self.apply(encoder_init)
forward(self, xs, ilens=None)
¶Calculate forward propagation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
xs |
Tensor |
Batch of the padded sequence of character ids (B, Tmax). Padded value should be 0. |
required |
ilens |
LongTensor |
Batch of lengths of each input batch (B,). |
None |
Returns:
Type | Description |
---|---|
Tensor |
Batch of the sequences of encoder states(B, Tmax, eunits). LongTensor: Batch of lengths of each sequence (B,) |
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/tacotron2/encoder.py
def forward(self, xs, ilens=None):
"""Calculate forward propagation.
Args:
xs (Tensor): Batch of the padded sequence of character ids (B, Tmax). Padded value should be 0.
ilens (LongTensor): Batch of lengths of each input batch (B,).
Returns:
Tensor: Batch of the sequences of encoder states(B, Tmax, eunits).
LongTensor: Batch of lengths of each sequence (B,)
"""
xs = self.embed(xs).transpose(1, 2)
if self.convs is not None:
for l in six.moves.range(len(self.convs)):
if self.use_residual:
xs += self.convs[l](xs)
else:
xs = self.convs[l](xs)
if self.blstm is None:
return xs.transpose(1, 2)
xs = pack_padded_sequence(xs.transpose(1, 2), ilens, batch_first=True)
self.blstm.flatten_parameters()
xs, _ = self.blstm(xs) # (B, Tmax, C)
xs, hlens = pad_packed_sequence(xs, batch_first=True)
return xs, hlens
inference(self, x)
¶Inference.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x |
Tensor |
The sequence of character ids (T,). |
required |
Returns:
Type | Description |
---|---|
Tensor |
The sequence of encoder states (T, eunits). |
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/tacotron2/encoder.py
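Usage sketch (import path assumed from the source location shown above): the encoder maps padded character ids to per-character hidden states; ilens must be sorted in decreasing order because forward uses pack_padded_sequence with its default settings.
import torch
from tools.espnet_minimal.nets.pytorch_backend.tacotron2.encoder import Encoder

idim = 40                                 # size of the character inventory (assumed)
enc = Encoder(idim).eval()
xs = torch.randint(1, idim, (2, 7))       # (B, Tmax) padded character ids, 0 = padding
xs[1, 5:] = 0
ilens = torch.tensor([7, 5])              # decreasing lengths
with torch.no_grad():
    hs, hlens = enc(xs, ilens)
print(hs.shape, hlens)                    # torch.Size([2, 7, 512]) tensor([7, 5])
# single utterance: enc.inference(xs[0]) returns (T, eunits)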
encoder_init(m)
¶
transformer
special
¶
add_sos_eos
¶
Unility functions for Transformer.
add_sos_eos(ys_pad, sos, eos, ignore_id)
¶Add `<sos>` and `<eos>` labels.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
ys_pad |
torch.Tensor |
batch of padded target sequences (B, Lmax) |
required |
sos |
int |
index of `<sos>` |
required |
eos |
int |
index of `<eos>` |
required |
ignore_id |
int |
index of padding |
required |
Returns:
Type | Description |
---|---|
Tuple[torch.Tensor, torch.Tensor] |
padded tensors ys_in (with `<sos>` prepended) and ys_out (with `<eos>` appended), each (B, Lmax) |
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/add_sos_eos.py
def add_sos_eos(ys_pad, sos, eos, ignore_id):
"""
Add `<sos>` and `<eos>` labels.
Arguments:
ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)
sos (int): index of `<sos>`
eos (int): index of `<eos>`
ignore_id (int): index of padding
Returns:
Tuple[torch.Tensor, torch.Tensor]: padded ys_in (with `<sos>` prepended) and ys_out (with `<eos>` appended), each (B, Lmax)
"""
from tools.espnet_minimal import pad_list
_sos = ys_pad.new([sos])
_eos = ys_pad.new([eos])
ys = [y[y != ignore_id] for y in ys_pad] # parse padded ys
ys_in = [torch.cat([_sos, y], dim=0) for y in ys]
ys_out = [torch.cat([y, _eos], dim=0) for y in ys]
return pad_list(ys_in, eos), pad_list(ys_out, ignore_id)
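Example (import path assumed from the source location shown above): ys_in gets `<sos>` prepended and is padded with eos, while ys_out gets `<eos>` appended and is padded with ignore_id, giving the usual decoder input/label pair.
import torch
from tools.espnet_minimal.nets.pytorch_backend.transformer.add_sos_eos import add_sos_eos

ignore_id, sos, eos = -1, 10, 11
ys_pad = torch.tensor([[3, 4, 5, -1],
                       [6, 7, -1, -1]])            # (B, Lmax), padded with ignore_id
ys_in, ys_out = add_sos_eos(ys_pad, sos, eos, ignore_id)
print(ys_in)   # tensor([[10,  3,  4,  5], [10,  6,  7, 11]])
print(ys_out)  # tensor([[ 3,  4,  5, 11], [ 6,  7, 11, -1]])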
attention
¶
Multi-Head Attention layer definition.
MultiHeadedAttention (Module)
¶Multi-Head Attention layer.
:param int n_head: the number of heads :param int n_feat: the number of features :param float dropout_rate: dropout rate
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/attention.py
class MultiHeadedAttention(nn.Module):
"""Multi-Head Attention layer.
:param int n_head: the number of heads
:param int n_feat: the number of features
:param float dropout_rate: dropout rate
"""
def __init__(self, n_head, n_feat, dropout_rate):
"""Construct an MultiHeadedAttention object."""
super(MultiHeadedAttention, self).__init__()
assert n_feat % n_head == 0
# We assume d_v always equals d_k
self.d_k = n_feat // n_head
self.h = n_head
self.linear_q = nn.Linear(n_feat, n_feat)
self.linear_k = nn.Linear(n_feat, n_feat)
self.linear_v = nn.Linear(n_feat, n_feat)
self.linear_out = nn.Linear(n_feat, n_feat)
self.attn = None
self.dropout = nn.Dropout(p=dropout_rate)
def forward(self, query, key, value, mask):
"""Compute 'Scaled Dot Product Attention'.
:param torch.Tensor query: (batch, time1, size)
:param torch.Tensor key: (batch, time2, size)
:param torch.Tensor value: (batch, time2, size)
:param torch.Tensor mask: (batch, time1, time2)
:param torch.nn.Dropout dropout:
:return torch.Tensor: attended and transformed `value` (batch, time1, d_model)
weighted by the query dot key attention (batch, head, time1, time2)
"""
n_batch = query.size(0)
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
q = q.transpose(1, 2) # (batch, head, time1, d_k)
k = k.transpose(1, 2) # (batch, head, time2, d_k)
v = v.transpose(1, 2) # (batch, head, time2, d_k)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) # (batch, head, time1, time2)
if mask is not None:
mask = mask.unsqueeze(1).eq(0) # (batch, 1, time1, time2)
min_value = float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
scores = scores.masked_fill(mask, min_value)
self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2)
else:
self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
p_attn = self.dropout(self.attn)
x = torch.matmul(p_attn, v) # (batch, head, time1, d_k)
x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model)
return self.linear_out(x) # (batch, time1, d_model)
__init__(self, n_head, n_feat, dropout_rate)
special
¶Construct a MultiHeadedAttention object.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/attention.py
def __init__(self, n_head, n_feat, dropout_rate):
"""Construct an MultiHeadedAttention object."""
super(MultiHeadedAttention, self).__init__()
assert n_feat % n_head == 0
# We assume d_v always equals d_k
self.d_k = n_feat // n_head
self.h = n_head
self.linear_q = nn.Linear(n_feat, n_feat)
self.linear_k = nn.Linear(n_feat, n_feat)
self.linear_v = nn.Linear(n_feat, n_feat)
self.linear_out = nn.Linear(n_feat, n_feat)
self.attn = None
self.dropout = nn.Dropout(p=dropout_rate)
forward(self, query, key, value, mask)
¶Compute 'Scaled Dot Product Attention'.
:param torch.Tensor query: (batch, time1, size)
:param torch.Tensor key: (batch, time2, size)
:param torch.Tensor value: (batch, time2, size)
:param torch.Tensor mask: (batch, time1, time2)
:param torch.nn.Dropout dropout:
:return torch.Tensor: attended and transformed value
(batch, time1, d_model)
weighted by the query dot key attention (batch, head, time1, time2)
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/attention.py
def forward(self, query, key, value, mask):
"""Compute 'Scaled Dot Product Attention'.
:param torch.Tensor query: (batch, time1, size)
:param torch.Tensor key: (batch, time2, size)
:param torch.Tensor value: (batch, time2, size)
:param torch.Tensor mask: (batch, time1, time2)
:param torch.nn.Dropout dropout:
:return torch.Tensor: attended and transformed `value` (batch, time1, d_model)
weighted by the query dot key attention (batch, head, time1, time2)
"""
n_batch = query.size(0)
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
q = q.transpose(1, 2) # (batch, head, time1, d_k)
k = k.transpose(1, 2) # (batch, head, time2, d_k)
v = v.transpose(1, 2) # (batch, head, time2, d_k)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) # (batch, head, time1, time2)
if mask is not None:
mask = mask.unsqueeze(1).eq(0) # (batch, 1, time1, time2)
min_value = float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
scores = scores.masked_fill(mask, min_value)
self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2)
else:
self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
p_attn = self.dropout(self.attn)
x = torch.matmul(p_attn, v) # (batch, head, time1, d_k)
x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model)
return self.linear_out(x) # (batch, time1, d_model)
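Usage sketch (import path assumed from the source location shown above): with a boolean mask of shape (batch, time1, time2), masked positions receive zero attention weight; the last computed weights stay available in self.attn.
import torch
from tools.espnet_minimal.nets.pytorch_backend.transformer.attention import MultiHeadedAttention

mha = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.0).eval()
q = torch.randn(2, 5, 256)                        # (batch, time1, n_feat)
kv = torch.randn(2, 7, 256)                       # (batch, time2, n_feat)
mask = torch.ones(2, 5, 7, dtype=torch.bool)      # True = attend, False = masked out
with torch.no_grad():
    out = mha(q, kv, kv, mask)
print(out.shape, mha.attn.shape)                  # torch.Size([2, 5, 256]) torch.Size([2, 4, 5, 7])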
decoder
¶
Decoder definition.
Decoder (BatchScorerInterface, Module)
¶Transformer decoder module.
:param int odim: output dim :param int attention_dim: dimension of attention :param int attention_heads: the number of heads of multi head attention :param int linear_units: the number of units of position-wise feed forward :param int num_blocks: the number of decoder blocks :param float dropout_rate: dropout rate :param float attention_dropout_rate: dropout rate for attention :param str or torch.nn.Module input_layer: input layer type :param bool use_output_layer: whether to use output layer :param class pos_enc_class: PositionalEncoding or ScaledPositionalEncoding :param bool normalize_before: whether to use layer_norm before the first block :param bool concat_after: whether to concat attention layer's input and output if True, additional linear will be applied. i.e. x -> x + linear(concat(x, att(x))) if False, no additional linear will be applied. i.e. x -> x + att(x)
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/decoder.py
class Decoder(BatchScorerInterface, torch.nn.Module):
"""Transfomer decoder module.
:param int odim: output dim
:param int attention_dim: dimension of attention
:param int attention_heads: the number of heads of multi head attention
:param int linear_units: the number of units of position-wise feed forward
:param int num_blocks: the number of decoder blocks
:param float dropout_rate: dropout rate
:param float attention_dropout_rate: dropout rate for attention
:param str or torch.nn.Module input_layer: input layer type
:param bool use_output_layer: whether to use output layer
:param class pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
:param bool normalize_before: whether to use layer_norm before the first block
:param bool concat_after: whether to concat attention layer's input and output
if True, additional linear will be applied. i.e. x -> x + linear(concat(x, att(x)))
if False, no additional linear will be applied. i.e. x -> x + att(x)
"""
def __init__(self, odim,
attention_dim=256,
attention_heads=4,
linear_units=2048,
num_blocks=6,
dropout_rate=0.1,
positional_dropout_rate=0.1,
self_attention_dropout_rate=0.0,
src_attention_dropout_rate=0.0,
input_layer="embed",
use_output_layer=True,
pos_enc_class=PositionalEncoding,
normalize_before=True,
concat_after=False):
"""Construct an Decoder object."""
torch.nn.Module.__init__(self)
if input_layer == "embed":
self.embed = torch.nn.Sequential(
torch.nn.Embedding(odim, attention_dim),
pos_enc_class(attention_dim, positional_dropout_rate)
)
elif input_layer == "linear":
self.embed = torch.nn.Sequential(
torch.nn.Linear(odim, attention_dim),
torch.nn.LayerNorm(attention_dim),
torch.nn.Dropout(dropout_rate),
torch.nn.ReLU(),
pos_enc_class(attention_dim, positional_dropout_rate)
)
elif isinstance(input_layer, torch.nn.Module):
self.embed = torch.nn.Sequential(
input_layer,
pos_enc_class(attention_dim, positional_dropout_rate)
)
else:
raise NotImplementedError("only `embed` or torch.nn.Module is supported.")
self.normalize_before = normalize_before
self.decoders = repeat(
num_blocks,
lambda: DecoderLayer(
attention_dim,
MultiHeadedAttention(attention_heads, attention_dim, self_attention_dropout_rate),
MultiHeadedAttention(attention_heads, attention_dim, src_attention_dropout_rate),
PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
dropout_rate,
normalize_before,
concat_after
)
)
if self.normalize_before:
self.after_norm = LayerNorm(attention_dim)
if use_output_layer:
self.output_layer = torch.nn.Linear(attention_dim, odim)
else:
self.output_layer = None
def forward(self, tgt, tgt_mask, memory, memory_mask):
"""Forward decoder.
:param torch.Tensor tgt: input token ids, int64 (batch, maxlen_out) if input_layer == "embed"
input tensor (batch, maxlen_out, #mels) in the other cases
:param torch.Tensor tgt_mask: input token mask, (batch, maxlen_out)
dtype=torch.uint8 in PyTorch 1.2-
dtype=torch.bool in PyTorch 1.2+ (include 1.2)
:param torch.Tensor memory: encoded memory, float32 (batch, maxlen_in, feat)
:param torch.Tensor memory_mask: encoded memory mask, (batch, maxlen_in)
dtype=torch.uint8 in PyTorch 1.2-
dtype=torch.bool in PyTorch 1.2+ (include 1.2)
:return x: decoded token score before softmax (batch, maxlen_out, token) if use_output_layer is True,
final block outputs (batch, maxlen_out, attention_dim) in the other cases
:rtype: torch.Tensor
:return tgt_mask: score mask before softmax (batch, maxlen_out)
:rtype: torch.Tensor
"""
x = self.embed(tgt)
x, tgt_mask, memory, memory_mask = self.decoders(x, tgt_mask, memory, memory_mask)
if self.normalize_before:
x = self.after_norm(x)
if self.output_layer is not None:
x = self.output_layer(x)
return x, tgt_mask
def forward_one_step(self, tgt, tgt_mask, memory, cache=None):
"""Forward one step.
:param torch.Tensor tgt: input token ids, int64 (batch, maxlen_out)
:param torch.Tensor tgt_mask: input token mask, (batch, maxlen_out)
dtype=torch.uint8 in PyTorch 1.2-
dtype=torch.bool in PyTorch 1.2+ (include 1.2)
:param torch.Tensor memory: encoded memory, float32 (batch, maxlen_in, feat)
:param List[torch.Tensor] cache: cached output list of (batch, max_time_out-1, size)
:return y, cache: NN output value and cache per `self.decoders`.
`y.shape` is (batch, maxlen_out, token)
:rtype: Tuple[torch.Tensor, List[torch.Tensor]]
"""
x = self.embed(tgt)
if cache is None:
cache = [None] * len(self.decoders)
new_cache = []
for c, decoder in zip(cache, self.decoders):
x, tgt_mask, memory, memory_mask = decoder(x, tgt_mask, memory, None, cache=c)
new_cache.append(x)
if self.normalize_before:
y = self.after_norm(x[:, -1])
else:
y = x[:, -1]
if self.output_layer is not None:
y = torch.log_softmax(self.output_layer(y), dim=-1)
return y, new_cache
# beam search API (see ScorerInterface)
def score(self, ys, state, x):
"""Score."""
# TODO(karita): remove this section after all ScorerInterface implements batch decoding
if ys.dim() == 1:
ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
logp, state = self.forward_one_step(ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state)
return logp.squeeze(0), state
# merge states
n_batch = len(ys)
n_layers = len(self.decoders)
if state[0] is None:
batch_state = None
else:
# transpose state of [batch, layer] into [layer, batch]
batch_state = [torch.stack([state[b][l] for b in range(n_batch)]) for l in range(n_layers)]
# batch decoding
ys_mask = subsequent_mask(ys.size(-1), device=x.device).unsqueeze(0)
logp, state = self.forward_one_step(ys, ys_mask, x, cache=batch_state)
# transpose state of [layer, batch] into [batch, layer]
state_list = [[state[l][b] for l in range(n_layers)] for b in range(n_batch)]
return logp, state_list
__init__(self, odim, attention_dim=256, attention_heads=4, linear_units=2048, num_blocks=6, dropout_rate=0.1, positional_dropout_rate=0.1, self_attention_dropout_rate=0.0, src_attention_dropout_rate=0.0, input_layer='embed', use_output_layer=True, pos_enc_class=<class 'tools.espnet_minimal.nets.pytorch_backend.transformer.embedding.PositionalEncoding'>, normalize_before=True, concat_after=False)
special
¶Construct a Decoder object.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/decoder.py
def __init__(self, odim,
attention_dim=256,
attention_heads=4,
linear_units=2048,
num_blocks=6,
dropout_rate=0.1,
positional_dropout_rate=0.1,
self_attention_dropout_rate=0.0,
src_attention_dropout_rate=0.0,
input_layer="embed",
use_output_layer=True,
pos_enc_class=PositionalEncoding,
normalize_before=True,
concat_after=False):
"""Construct an Decoder object."""
torch.nn.Module.__init__(self)
if input_layer == "embed":
self.embed = torch.nn.Sequential(
torch.nn.Embedding(odim, attention_dim),
pos_enc_class(attention_dim, positional_dropout_rate)
)
elif input_layer == "linear":
self.embed = torch.nn.Sequential(
torch.nn.Linear(odim, attention_dim),
torch.nn.LayerNorm(attention_dim),
torch.nn.Dropout(dropout_rate),
torch.nn.ReLU(),
pos_enc_class(attention_dim, positional_dropout_rate)
)
elif isinstance(input_layer, torch.nn.Module):
self.embed = torch.nn.Sequential(
input_layer,
pos_enc_class(attention_dim, positional_dropout_rate)
)
else:
raise NotImplementedError("only `embed` or torch.nn.Module is supported.")
self.normalize_before = normalize_before
self.decoders = repeat(
num_blocks,
lambda: DecoderLayer(
attention_dim,
MultiHeadedAttention(attention_heads, attention_dim, self_attention_dropout_rate),
MultiHeadedAttention(attention_heads, attention_dim, src_attention_dropout_rate),
PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
dropout_rate,
normalize_before,
concat_after
)
)
if self.normalize_before:
self.after_norm = LayerNorm(attention_dim)
if use_output_layer:
self.output_layer = torch.nn.Linear(attention_dim, odim)
else:
self.output_layer = None
forward(self, tgt, tgt_mask, memory, memory_mask)
¶Forward decoder.
:param torch.Tensor tgt: input token ids, int64 (batch, maxlen_out) if input_layer == "embed" input tensor (batch, maxlen_out, #mels) in the other cases :param torch.Tensor tgt_mask: input token mask, (batch, maxlen_out) dtype=torch.uint8 in PyTorch 1.2- dtype=torch.bool in PyTorch 1.2+ (include 1.2) :param torch.Tensor memory: encoded memory, float32 (batch, maxlen_in, feat) :param torch.Tensor memory_mask: encoded memory mask, (batch, maxlen_in) dtype=torch.uint8 in PyTorch 1.2- dtype=torch.bool in PyTorch 1.2+ (include 1.2) :return x: decoded token score before softmax (batch, maxlen_out, token) if use_output_layer is True, final block outputs (batch, maxlen_out, attention_dim) in the other cases :rtype: torch.Tensor :return tgt_mask: score mask before softmax (batch, maxlen_out) :rtype: torch.Tensor
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/decoder.py
def forward(self, tgt, tgt_mask, memory, memory_mask):
"""Forward decoder.
:param torch.Tensor tgt: input token ids, int64 (batch, maxlen_out) if input_layer == "embed"
input tensor (batch, maxlen_out, #mels) in the other cases
:param torch.Tensor tgt_mask: input token mask, (batch, maxlen_out)
dtype=torch.uint8 in PyTorch 1.2-
dtype=torch.bool in PyTorch 1.2+ (include 1.2)
:param torch.Tensor memory: encoded memory, float32 (batch, maxlen_in, feat)
:param torch.Tensor memory_mask: encoded memory mask, (batch, maxlen_in)
dtype=torch.uint8 in PyTorch 1.2-
dtype=torch.bool in PyTorch 1.2+ (include 1.2)
:return x: decoded token score before softmax (batch, maxlen_out, token) if use_output_layer is True,
final block outputs (batch, maxlen_out, attention_dim) in the other cases
:rtype: torch.Tensor
:return tgt_mask: score mask before softmax (batch, maxlen_out)
:rtype: torch.Tensor
"""
x = self.embed(tgt)
x, tgt_mask, memory, memory_mask = self.decoders(x, tgt_mask, memory, memory_mask)
if self.normalize_before:
x = self.after_norm(x)
if self.output_layer is not None:
x = self.output_layer(x)
return x, tgt_mask
forward_one_step(self, tgt, tgt_mask, memory, cache=None)
¶Forward one step.
:param torch.Tensor tgt: input token ids, int64 (batch, maxlen_out)
:param torch.Tensor tgt_mask: input token mask, (batch, maxlen_out)
dtype=torch.uint8 in PyTorch 1.2-
dtype=torch.bool in PyTorch 1.2+ (include 1.2)
:param torch.Tensor memory: encoded memory, float32 (batch, maxlen_in, feat)
:param List[torch.Tensor] cache: cached output list of (batch, max_time_out-1, size)
:return y, cache: NN output value and cache per self.decoders
.
y.shape
is (batch, maxlen_out, token)
:rtype: Tuple[torch.Tensor, List[torch.Tensor]]
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/decoder.py
def forward_one_step(self, tgt, tgt_mask, memory, cache=None):
"""Forward one step.
:param torch.Tensor tgt: input token ids, int64 (batch, maxlen_out)
:param torch.Tensor tgt_mask: input token mask, (batch, maxlen_out)
dtype=torch.uint8 in PyTorch 1.2-
dtype=torch.bool in PyTorch 1.2+ (include 1.2)
:param torch.Tensor memory: encoded memory, float32 (batch, maxlen_in, feat)
:param List[torch.Tensor] cache: cached output list of (batch, max_time_out-1, size)
:return y, cache: NN output value and cache per `self.decoders`.
`y.shape` is (batch, maxlen_out, token)
:rtype: Tuple[torch.Tensor, List[torch.Tensor]]
"""
x = self.embed(tgt)
if cache is None:
cache = [None] * len(self.decoders)
new_cache = []
for c, decoder in zip(cache, self.decoders):
x, tgt_mask, memory, memory_mask = decoder(x, tgt_mask, memory, None, cache=c)
new_cache.append(x)
if self.normalize_before:
y = self.after_norm(x[:, -1])
else:
y = x[:, -1]
if self.output_layer is not None:
y = torch.log_softmax(self.output_layer(y), dim=-1)
return y, new_cache
score(self, ys, state, x)
¶Score.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/decoder.py
def score(self, ys, state, x):
"""Score."""
# TODO(karita): remove this section after all ScorerInterface implements batch decoding
if ys.dim() == 1:
ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
logp, state = self.forward_one_step(ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state)
return logp.squeeze(0), state
# merge states
n_batch = len(ys)
n_layers = len(self.decoders)
if state[0] is None:
batch_state = None
else:
# transpose state of [batch, layer] into [layer, batch]
batch_state = [torch.stack([state[b][l] for b in range(n_batch)]) for l in range(n_layers)]
# batch decoding
ys_mask = subsequent_mask(ys.size(-1), device=x.device).unsqueeze(0)
logp, state = self.forward_one_step(ys, ys_mask, x, cache=batch_state)
# transpose state of [layer, batch] into [batch, layer]
state_list = [[state[l][b] for l in range(n_layers)] for b in range(n_batch)]
return logp, state_list
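Usage sketch (import path assumed from the source location shown above): a full teacher-forced forward pass over token ids with a hand-built causal mask; the package's subsequent_mask helper would normally build tgt_mask, so the torch.tril construction below is only an equivalent stand-in.
import torch
from tools.espnet_minimal.nets.pytorch_backend.transformer.decoder import Decoder

odim = 50                                          # output vocabulary size (assumed)
dec = Decoder(odim, attention_dim=256, attention_heads=4, num_blocks=2).eval()
memory = torch.randn(2, 11, 256)                   # encoder outputs (batch, maxlen_in, feat)
memory_mask = torch.ones(2, 1, 11, dtype=torch.bool)
tgt = torch.randint(0, odim, (2, 6))               # token ids (batch, maxlen_out)
tgt_mask = torch.tril(torch.ones(6, 6)).bool().unsqueeze(0)  # causal mask
with torch.no_grad():
    x, _ = dec(tgt, tgt_mask, memory, memory_mask)
print(x.shape)                                     # torch.Size([2, 6, 50])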
decoder_layer
¶
Decoder self-attention layer definition.
DecoderLayer (Module)
¶Single decoder layer module.
:param int size: input dim :param services.hci.speech.espnet_minimal.nets.pytorch_backend.transformer.attention.MultiHeadedAttention self_attn: self attention module :param services.hci.speech.espnet_minimal.nets.pytorch_backend.transformer.attention.MultiHeadedAttention src_attn: source attention module :param services.hci.speech.espnet_minimal.nets.pytorch_backend.transformer.positionwise_feed_forward.PositionwiseFeedForward feed_forward: feed forward layer module :param float dropout_rate: dropout rate :param bool normalize_before: whether to use layer_norm before the first block :param bool concat_after: whether to concat attention layer's input and output if True, additional linear will be applied. i.e. x -> x + linear(concat(x, att(x))) if False, no additional linear will be applied. i.e. x -> x + att(x)
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/decoder_layer.py
class DecoderLayer(nn.Module):
"""Single decoder layer module.
:param int size: input dim
:param services.hci.speech.espnet_minimal.nets.pytorch_backend.transformer.attention.MultiHeadedAttention self_attn: self attention module
:param services.hci.speech.espnet_minimal.nets.pytorch_backend.transformer.attention.MultiHeadedAttention src_attn: source attention module
:param services.hci.speech.espnet_minimal.nets.pytorch_backend.transformer.positionwise_feed_forward.PositionwiseFeedForward feed_forward:
feed forward layer module
:param float dropout_rate: dropout rate
:param bool normalize_before: whether to use layer_norm before the first block
:param bool concat_after: whether to concat attention layer's input and output
if True, additional linear will be applied. i.e. x -> x + linear(concat(x, att(x)))
if False, no additional linear will be applied. i.e. x -> x + att(x)
"""
def __init__(self, size, self_attn, src_attn, feed_forward, dropout_rate,
normalize_before=True, concat_after=False):
"""Construct an DecoderLayer object."""
super(DecoderLayer, self).__init__()
self.size = size
self.self_attn = self_attn
self.src_attn = src_attn
self.feed_forward = feed_forward
self.norm1 = LayerNorm(size)
self.norm2 = LayerNorm(size)
self.norm3 = LayerNorm(size)
self.dropout = nn.Dropout(dropout_rate)
self.normalize_before = normalize_before
self.concat_after = concat_after
if self.concat_after:
self.concat_linear1 = nn.Linear(size + size, size)
self.concat_linear2 = nn.Linear(size + size, size)
def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
"""Compute decoded features.
Args:
tgt (torch.Tensor): decoded previous target features (batch, max_time_out, size)
tgt_mask (torch.Tensor): mask for x (batch, max_time_out)
memory (torch.Tensor): encoded source features (batch, max_time_in, size)
memory_mask (torch.Tensor): mask for memory (batch, max_time_in)
cache (torch.Tensor): cached output (batch, max_time_out-1, size)
"""
residual = tgt
if self.normalize_before:
tgt = self.norm1(tgt)
if cache is None:
tgt_q = tgt
tgt_q_mask = tgt_mask
else:
# compute only the last frame query keeping dim: max_time_out -> 1
assert cache.shape == (tgt.shape[0], tgt.shape[1] - 1, self.size), \
f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
tgt_q = tgt[:, -1:, :]
residual = residual[:, -1:, :]
tgt_q_mask = None
if tgt_mask is not None:
tgt_q_mask = tgt_mask[:, -1:, :]
if self.concat_after:
tgt_concat = torch.cat((tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1)
x = residual + self.concat_linear1(tgt_concat)
else:
x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
if not self.normalize_before:
x = self.norm1(x)
residual = x
if self.normalize_before:
x = self.norm2(x)
if self.concat_after:
x_concat = torch.cat((x, self.src_attn(x, memory, memory, memory_mask)), dim=-1)
x = residual + self.concat_linear2(x_concat)
else:
x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
if not self.normalize_before:
x = self.norm2(x)
residual = x
if self.normalize_before:
x = self.norm3(x)
x = residual + self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm3(x)
if cache is not None:
x = torch.cat([cache, x], dim=1)
return x, tgt_mask, memory, memory_mask
__init__(self, size, self_attn, src_attn, feed_forward, dropout_rate, normalize_before=True, concat_after=False)
special
¶Construct a DecoderLayer object.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/decoder_layer.py
def __init__(self, size, self_attn, src_attn, feed_forward, dropout_rate,
normalize_before=True, concat_after=False):
"""Construct an DecoderLayer object."""
super(DecoderLayer, self).__init__()
self.size = size
self.self_attn = self_attn
self.src_attn = src_attn
self.feed_forward = feed_forward
self.norm1 = LayerNorm(size)
self.norm2 = LayerNorm(size)
self.norm3 = LayerNorm(size)
self.dropout = nn.Dropout(dropout_rate)
self.normalize_before = normalize_before
self.concat_after = concat_after
if self.concat_after:
self.concat_linear1 = nn.Linear(size + size, size)
self.concat_linear2 = nn.Linear(size + size, size)
forward(self, tgt, tgt_mask, memory, memory_mask, cache=None)
¶Compute decoded features.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tgt |
torch.Tensor |
decoded previous target features (batch, max_time_out, size) |
required |
tgt_mask |
torch.Tensor |
mask for x (batch, max_time_out) |
required |
memory |
torch.Tensor |
encoded source features (batch, max_time_in, size) |
required |
memory_mask |
torch.Tensor |
mask for memory (batch, max_time_in) |
required |
cache |
torch.Tensor |
cached output (batch, max_time_out-1, size) |
None |
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/decoder_layer.py
def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
"""Compute decoded features.
Args:
tgt (torch.Tensor): decoded previous target features (batch, max_time_out, size)
tgt_mask (torch.Tensor): mask for x (batch, max_time_out)
memory (torch.Tensor): encoded source features (batch, max_time_in, size)
memory_mask (torch.Tensor): mask for memory (batch, max_time_in)
cache (torch.Tensor): cached output (batch, max_time_out-1, size)
"""
residual = tgt
if self.normalize_before:
tgt = self.norm1(tgt)
if cache is None:
tgt_q = tgt
tgt_q_mask = tgt_mask
else:
# compute only the last frame query keeping dim: max_time_out -> 1
assert cache.shape == (tgt.shape[0], tgt.shape[1] - 1, self.size), \
f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
tgt_q = tgt[:, -1:, :]
residual = residual[:, -1:, :]
tgt_q_mask = None
if tgt_mask is not None:
tgt_q_mask = tgt_mask[:, -1:, :]
if self.concat_after:
tgt_concat = torch.cat((tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1)
x = residual + self.concat_linear1(tgt_concat)
else:
x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
if not self.normalize_before:
x = self.norm1(x)
residual = x
if self.normalize_before:
x = self.norm2(x)
if self.concat_after:
x_concat = torch.cat((x, self.src_attn(x, memory, memory, memory_mask)), dim=-1)
x = residual + self.concat_linear2(x_concat)
else:
x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
if not self.normalize_before:
x = self.norm2(x)
residual = x
if self.normalize_before:
x = self.norm3(x)
x = residual + self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm3(x)
if cache is not None:
x = torch.cat([cache, x], dim=1)
return x, tgt_mask, memory, memory_mask
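Sketch of the cache path (import paths assumed from the source locations shown above; the plain MLP stands in for the package's PositionwiseFeedForward, and the cache values below are shape placeholders only): when a cache of shape (batch, t-1, size) is supplied, only the newest frame's query is computed and the result is concatenated back onto the cache.
import torch
from tools.espnet_minimal.nets.pytorch_backend.transformer.attention import MultiHeadedAttention
from tools.espnet_minimal.nets.pytorch_backend.transformer.decoder_layer import DecoderLayer

size = 64
ff = torch.nn.Sequential(torch.nn.Linear(size, 128), torch.nn.ReLU(), torch.nn.Linear(128, size))
layer = DecoderLayer(size,
                     MultiHeadedAttention(4, size, 0.0),
                     MultiHeadedAttention(4, size, 0.0),
                     ff, dropout_rate=0.0).eval()
memory = torch.randn(1, 9, size)                  # encoded source features
tgt = torch.randn(1, 5, size)                     # 5 decoded frames so far
cache = tgt[:, :4, :]                             # placeholder for the previous layer outputs
with torch.no_grad():
    x, *_ = layer(tgt, None, memory, None, cache=cache)
print(x.shape)                                    # torch.Size([1, 5, 64]) -> cache plus the new frame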
embedding
¶
Positional Encoding Module.
PositionalEncoding (Module)
¶Positional encoding.
:param int d_model: embedding dim :param float dropout_rate: dropout rate :param int max_len: maximum input length
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/embedding.py
class PositionalEncoding(torch.nn.Module):
"""Positional encoding.
:param int d_model: embedding dim
:param float dropout_rate: dropout rate
:param int max_len: maximum input length
"""
def __init__(self, d_model, dropout_rate, max_len=5000):
"""Construct an PositionalEncoding object."""
super(PositionalEncoding, self).__init__()
self.d_model = d_model
self.xscale = math.sqrt(self.d_model)
self.dropout = torch.nn.Dropout(p=dropout_rate)
self.pe = None
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
self._register_load_state_dict_pre_hook(_pre_hook)
def extend_pe(self, x):
"""Reset the positional encodings."""
if self.pe is not None:
if self.pe.size(1) >= x.size(1):
if self.pe.dtype != x.dtype or self.pe.device != x.device:
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
pe = torch.zeros(x.size(1), self.d_model)
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(torch.arange(0, self.d_model, 2, dtype=torch.float32) *
-(math.log(10000.0) / self.d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward(self, x: torch.Tensor):
"""Add positional encoding.
Args:
x (torch.Tensor): Input. Its shape is (batch, time, ...)
Returns:
torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
"""
self.extend_pe(x)
x = x * self.xscale + self.pe[:, :x.size(1)]
return self.dropout(x)
__init__(self, d_model, dropout_rate, max_len=5000)
special
¶Construct a PositionalEncoding object.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/embedding.py
def __init__(self, d_model, dropout_rate, max_len=5000):
"""Construct an PositionalEncoding object."""
super(PositionalEncoding, self).__init__()
self.d_model = d_model
self.xscale = math.sqrt(self.d_model)
self.dropout = torch.nn.Dropout(p=dropout_rate)
self.pe = None
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
self._register_load_state_dict_pre_hook(_pre_hook)
extend_pe(self, x)
¶Reset the positional encodings.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/embedding.py
def extend_pe(self, x):
"""Reset the positional encodings."""
if self.pe is not None:
if self.pe.size(1) >= x.size(1):
if self.pe.dtype != x.dtype or self.pe.device != x.device:
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
pe = torch.zeros(x.size(1), self.d_model)
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(torch.arange(0, self.d_model, 2, dtype=torch.float32) *
-(math.log(10000.0) / self.d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.pe = pe.to(device=x.device, dtype=x.dtype)
forward(self, x)
¶Add positional encoding.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x |
torch.Tensor |
Input. Its shape is (batch, time, ...) |
required |
Returns:
Type | Description |
---|---|
torch.Tensor |
Encoded tensor. Its shape is (batch, time, ...) |
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/embedding.py
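Usage sketch (the import path matches the pos_enc_class default shown in the Decoder signature above): the input is scaled by sqrt(d_model) and the sinusoidal table PE(pos, 2i) = sin(pos / 10000^(2i/d_model)), PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)) is added, as built in extend_pe above.
import torch
from tools.espnet_minimal.nets.pytorch_backend.transformer.embedding import PositionalEncoding

pos_enc = PositionalEncoding(d_model=256, dropout_rate=0.0)
x = torch.randn(2, 30, 256)             # (batch, time, d_model)
y = pos_enc(x)                          # x * sqrt(256) + pe[:, :30]
print(y.shape, pos_enc.pe.shape)        # torch.Size([2, 30, 256]) torch.Size([1, 5000, 256])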
ScaledPositionalEncoding (PositionalEncoding)
¶Scaled positional encoding module.
See also: Sec. 3.2 https://arxiv.org/pdf/1809.08895.pdf
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/embedding.py
class ScaledPositionalEncoding(PositionalEncoding):
"""Scaled positional encoding module.
See also: Sec. 3.2 https://arxiv.org/pdf/1809.08895.pdf
"""
def __init__(self, d_model, dropout_rate, max_len=5000):
"""Initialize class.
:param int d_model: embedding dim
:param float dropout_rate: dropout rate
:param int max_len: maximum input length
"""
super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
self.alpha = torch.nn.Parameter(torch.tensor(1.0))
def reset_parameters(self):
"""Reset parameters."""
self.alpha.data = torch.tensor(1.0)
def forward(self, x):
"""Add positional encoding.
Args:
x (torch.Tensor): Input. Its shape is (batch, time, ...)
Returns:
torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
"""
self.extend_pe(x)
x = x + self.alpha * self.pe[:, :x.size(1)]
return self.dropout(x)
__init__(self, d_model, dropout_rate, max_len=5000)
special
¶Initialize class.
:param int d_model: embedding dim :param float dropout_rate: dropout rate :param int max_len: maximum input length
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/embedding.py
def __init__(self, d_model, dropout_rate, max_len=5000):
"""Initialize class.
:param int d_model: embedding dim
:param float dropout_rate: dropout rate
:param int max_len: maximum input length
"""
super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
self.alpha = torch.nn.Parameter(torch.tensor(1.0))
forward(self, x)
¶Add positional encoding.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | torch.Tensor | Input. Its shape is (batch, time, ...) | required |

Returns:

| Type | Description |
|---|---|
| torch.Tensor | Encoded tensor. Its shape is (batch, time, ...) |
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/embedding.py
reset_parameters(self)
¶
encoder
¶
Encoder definition.
Encoder (Module)
¶Transformer encoder module.
:param int idim: input dim
:param int attention_dim: dimension of attention
:param int attention_heads: the number of heads of multi head attention
:param int linear_units: the number of units of position-wise feed forward
:param int num_blocks: the number of encoder blocks
:param float dropout_rate: dropout rate
:param float attention_dropout_rate: dropout rate in attention
:param float positional_dropout_rate: dropout rate after adding positional encoding
:param str or torch.nn.Module input_layer: input layer type
:param class pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
:param bool normalize_before: whether to use layer_norm before the first block
:param bool concat_after: whether to concat attention layer's input and output
    if True, additional linear will be applied. i.e. x -> x + linear(concat(x, att(x)))
    if False, no additional linear will be applied. i.e. x -> x + att(x)
:param str positionwise_layer_type: linear or conv1d
:param int positionwise_conv_kernel_size: kernel size of positionwise conv1d layer
:param int padding_idx: padding_idx for input_layer=embed
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/encoder.py
class Encoder(torch.nn.Module):
"""Transformer encoder module.
:param int idim: input dim
:param int attention_dim: dimention of attention
:param int attention_heads: the number of heads of multi head attention
:param int linear_units: the number of units of position-wise feed forward
:param int num_blocks: the number of decoder blocks
:param float dropout_rate: dropout rate
:param float attention_dropout_rate: dropout rate in attention
:param float positional_dropout_rate: dropout rate after adding positional encoding
:param str or torch.nn.Module input_layer: input layer type
:param class pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
:param bool normalize_before: whether to use layer_norm before the first block
:param bool concat_after: whether to concat attention layer's input and output
if True, additional linear will be applied. i.e. x -> x + linear(concat(x, att(x)))
if False, no additional linear will be applied. i.e. x -> x + att(x)
:param str positionwise_layer_type: linear of conv1d
:param int positionwise_conv_kernel_size: kernel size of positionwise conv1d layer
:param int padding_idx: padding_idx for input_layer=embed
"""
def __init__(self, idim,
attention_dim=256,
attention_heads=4,
linear_units=2048,
num_blocks=6,
dropout_rate=0.1,
positional_dropout_rate=0.1,
attention_dropout_rate=0.0,
input_layer="conv2d",
pos_enc_class=PositionalEncoding,
normalize_before=True,
concat_after=False,
positionwise_layer_type="linear",
positionwise_conv_kernel_size=1,
padding_idx=-1):
"""Construct an Encoder object."""
super(Encoder, self).__init__()
if input_layer == "linear":
self.embed = torch.nn.Sequential(
torch.nn.Linear(idim, attention_dim),
torch.nn.LayerNorm(attention_dim),
torch.nn.Dropout(dropout_rate),
torch.nn.ReLU(),
pos_enc_class(attention_dim, positional_dropout_rate)
)
elif input_layer == "conv2d":
self.embed = Conv2dSubsampling(idim, attention_dim, dropout_rate)
elif input_layer == "embed":
self.embed = torch.nn.Sequential(
torch.nn.Embedding(idim, attention_dim, padding_idx=padding_idx),
pos_enc_class(attention_dim, positional_dropout_rate)
)
elif isinstance(input_layer, torch.nn.Module):
self.embed = torch.nn.Sequential(
input_layer,
pos_enc_class(attention_dim, positional_dropout_rate),
)
elif input_layer is None:
self.embed = torch.nn.Sequential(
pos_enc_class(attention_dim, positional_dropout_rate)
)
else:
raise ValueError("unknown input_layer: " + input_layer)
self.normalize_before = normalize_before
if positionwise_layer_type == "linear":
positionwise_layer = PositionwiseFeedForward
positionwise_layer_args = (attention_dim, linear_units, dropout_rate)
elif positionwise_layer_type == "conv1d":
positionwise_layer = MultiLayeredConv1d
positionwise_layer_args = (attention_dim, linear_units, positionwise_conv_kernel_size, dropout_rate)
elif positionwise_layer_type == "conv1d-linear":
positionwise_layer = Conv1dLinear
positionwise_layer_args = (attention_dim, linear_units, positionwise_conv_kernel_size, dropout_rate)
else:
raise NotImplementedError("Support only linear or conv1d.")
self.encoders = repeat(
num_blocks,
lambda: EncoderLayer(
attention_dim,
MultiHeadedAttention(attention_heads, attention_dim, attention_dropout_rate),
positionwise_layer(*positionwise_layer_args),
dropout_rate,
normalize_before,
concat_after
)
)
if self.normalize_before:
self.after_norm = LayerNorm(attention_dim)
def forward(self, xs, masks):
"""Embed positions in tensor.
:param torch.Tensor xs: input tensor
:param torch.Tensor masks: input mask
:return: position embedded tensor and mask
:rtype Tuple[torch.Tensor, torch.Tensor]:
"""
if isinstance(self.embed, Conv2dSubsampling):
xs, masks = self.embed(xs, masks)
else:
xs = self.embed(xs)
xs, masks = self.encoders(xs, masks)
if self.normalize_before:
xs = self.after_norm(xs)
return xs, masks
__init__(self, idim, attention_dim=256, attention_heads=4, linear_units=2048, num_blocks=6, dropout_rate=0.1, positional_dropout_rate=0.1, attention_dropout_rate=0.0, input_layer='conv2d', pos_enc_class=PositionalEncoding, normalize_before=True, concat_after=False, positionwise_layer_type='linear', positionwise_conv_kernel_size=1, padding_idx=-1)
special
¶Construct an Encoder object.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/encoder.py
def __init__(self, idim,
attention_dim=256,
attention_heads=4,
linear_units=2048,
num_blocks=6,
dropout_rate=0.1,
positional_dropout_rate=0.1,
attention_dropout_rate=0.0,
input_layer="conv2d",
pos_enc_class=PositionalEncoding,
normalize_before=True,
concat_after=False,
positionwise_layer_type="linear",
positionwise_conv_kernel_size=1,
padding_idx=-1):
"""Construct an Encoder object."""
super(Encoder, self).__init__()
if input_layer == "linear":
self.embed = torch.nn.Sequential(
torch.nn.Linear(idim, attention_dim),
torch.nn.LayerNorm(attention_dim),
torch.nn.Dropout(dropout_rate),
torch.nn.ReLU(),
pos_enc_class(attention_dim, positional_dropout_rate)
)
elif input_layer == "conv2d":
self.embed = Conv2dSubsampling(idim, attention_dim, dropout_rate)
elif input_layer == "embed":
self.embed = torch.nn.Sequential(
torch.nn.Embedding(idim, attention_dim, padding_idx=padding_idx),
pos_enc_class(attention_dim, positional_dropout_rate)
)
elif isinstance(input_layer, torch.nn.Module):
self.embed = torch.nn.Sequential(
input_layer,
pos_enc_class(attention_dim, positional_dropout_rate),
)
elif input_layer is None:
self.embed = torch.nn.Sequential(
pos_enc_class(attention_dim, positional_dropout_rate)
)
else:
raise ValueError("unknown input_layer: " + input_layer)
self.normalize_before = normalize_before
if positionwise_layer_type == "linear":
positionwise_layer = PositionwiseFeedForward
positionwise_layer_args = (attention_dim, linear_units, dropout_rate)
elif positionwise_layer_type == "conv1d":
positionwise_layer = MultiLayeredConv1d
positionwise_layer_args = (attention_dim, linear_units, positionwise_conv_kernel_size, dropout_rate)
elif positionwise_layer_type == "conv1d-linear":
positionwise_layer = Conv1dLinear
positionwise_layer_args = (attention_dim, linear_units, positionwise_conv_kernel_size, dropout_rate)
else:
raise NotImplementedError("Support only linear or conv1d.")
self.encoders = repeat(
num_blocks,
lambda: EncoderLayer(
attention_dim,
MultiHeadedAttention(attention_heads, attention_dim, attention_dropout_rate),
positionwise_layer(*positionwise_layer_args),
dropout_rate,
normalize_before,
concat_after
)
)
if self.normalize_before:
self.after_norm = LayerNorm(attention_dim)
forward(self, xs, masks)
¶Embed positions in tensor.
:param torch.Tensor xs: input tensor
:param torch.Tensor masks: input mask
:return: position embedded tensor and mask
:rtype: Tuple[torch.Tensor, torch.Tensor]
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/encoder.py
def forward(self, xs, masks):
"""Embed positions in tensor.
:param torch.Tensor xs: input tensor
:param torch.Tensor masks: input mask
:return: position embedded tensor and mask
:rtype Tuple[torch.Tensor, torch.Tensor]:
"""
if isinstance(self.embed, Conv2dSubsampling):
xs, masks = self.embed(xs, masks)
else:
xs = self.embed(xs)
xs, masks = self.encoders(xs, masks)
if self.normalize_before:
xs = self.after_norm(xs)
return xs, masks
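To illustrate how the pieces fit together, here is a rough usage sketch (not from the source; the feature dimension 83 and the (batch, 1, time) mask layout are assumptions based on the docstrings above). With the default input_layer='conv2d', the front-end subsamples the time axis to roughly a quarter of its length before the encoder blocks run.

```python
# Rough usage sketch for the Transformer Encoder (assumed shapes).
import torch
from tools.espnet_minimal.nets.pytorch_backend.transformer.encoder import Encoder

enc = Encoder(idim=83, attention_dim=256, attention_heads=4, num_blocks=6)
xs = torch.randn(4, 200, 83)                     # (batch, time, idim) features
masks = torch.ones(4, 1, 200, dtype=torch.bool)  # True at non-padded positions
hs, hs_masks = enc(xs, masks)                    # conv2d front-end subsamples time to ~1/4
```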
encoder_layer
¶
Encoder self-attention layer definition.
EncoderLayer (Module)
¶Encoder layer module.
:param int size: input dim
:param services.hci.speech.espnet_minimal.nets.pytorch_backend.transformer.attention.MultiHeadedAttention self_attn: self attention module
:param services.hci.speech.espnet_minimal.nets.pytorch_backend.transformer.positionwise_feed_forward.PositionwiseFeedForward feed_forward: feed forward module
:param float dropout_rate: dropout rate
:param bool normalize_before: whether to use layer_norm before the first block
:param bool concat_after: whether to concat attention layer's input and output
    if True, additional linear will be applied. i.e. x -> x + linear(concat(x, att(x)))
    if False, no additional linear will be applied. i.e. x -> x + att(x)
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/encoder_layer.py
class EncoderLayer(nn.Module):
"""Encoder layer module.
:param int size: input dim
:param services.hci.speech.espnet_minimal.nets.pytorch_backend.transformer.attention.MultiHeadedAttention self_attn: self attention module
:param services.hci.speech.espnet_minimal.nets.pytorch_backend.transformer.positionwise_feed_forward.PositionwiseFeedForward feed_forward:
feed forward module
:param float dropout_rate: dropout rate
:param bool normalize_before: whether to use layer_norm before the first block
:param bool concat_after: whether to concat attention layer's input and output
if True, additional linear will be applied. i.e. x -> x + linear(concat(x, att(x)))
if False, no additional linear will be applied. i.e. x -> x + att(x)
"""
def __init__(self, size, self_attn, feed_forward, dropout_rate,
normalize_before=True, concat_after=False):
"""Construct an EncoderLayer object."""
super(EncoderLayer, self).__init__()
self.self_attn = self_attn
self.feed_forward = feed_forward
self.norm1 = LayerNorm(size)
self.norm2 = LayerNorm(size)
self.dropout = nn.Dropout(dropout_rate)
self.size = size
self.normalize_before = normalize_before
self.concat_after = concat_after
if self.concat_after:
self.concat_linear = nn.Linear(size + size, size)
def forward(self, x, mask):
"""Compute encoded features.
:param torch.Tensor x: encoded source features (batch, max_time_in, size)
:param torch.Tensor mask: mask for x (batch, max_time_in)
:rtype: Tuple[torch.Tensor, torch.Tensor]
"""
residual = x
if self.normalize_before:
x = self.norm1(x)
if self.concat_after:
x_concat = torch.cat((x, self.self_attn(x, x, x, mask)), dim=-1)
x = residual + self.concat_linear(x_concat)
else:
x = residual + self.dropout(self.self_attn(x, x, x, mask))
if not self.normalize_before:
x = self.norm1(x)
residual = x
if self.normalize_before:
x = self.norm2(x)
x = residual + self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm2(x)
return x, mask
__init__(self, size, self_attn, feed_forward, dropout_rate, normalize_before=True, concat_after=False)
special
¶Construct an EncoderLayer object.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/encoder_layer.py
def __init__(self, size, self_attn, feed_forward, dropout_rate,
normalize_before=True, concat_after=False):
"""Construct an EncoderLayer object."""
super(EncoderLayer, self).__init__()
self.self_attn = self_attn
self.feed_forward = feed_forward
self.norm1 = LayerNorm(size)
self.norm2 = LayerNorm(size)
self.dropout = nn.Dropout(dropout_rate)
self.size = size
self.normalize_before = normalize_before
self.concat_after = concat_after
if self.concat_after:
self.concat_linear = nn.Linear(size + size, size)
forward(self, x, mask)
¶Compute encoded features.
:param torch.Tensor x: encoded source features (batch, max_time_in, size)
:param torch.Tensor mask: mask for x (batch, max_time_in)
:rtype: Tuple[torch.Tensor, torch.Tensor]
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/encoder_layer.py
def forward(self, x, mask):
"""Compute encoded features.
:param torch.Tensor x: encoded source features (batch, max_time_in, size)
:param torch.Tensor mask: mask for x (batch, max_time_in)
:rtype: Tuple[torch.Tensor, torch.Tensor]
"""
residual = x
if self.normalize_before:
x = self.norm1(x)
if self.concat_after:
x_concat = torch.cat((x, self.self_attn(x, x, x, mask)), dim=-1)
x = residual + self.concat_linear(x_concat)
else:
x = residual + self.dropout(self.self_attn(x, x, x, mask))
if not self.normalize_before:
x = self.norm1(x)
residual = x
if self.normalize_before:
x = self.norm2(x)
x = residual + self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm2(x)
return x, mask
initializer
¶
Parameter initialization.
initialize(model, init_type='pytorch')
¶Initialize Transformer module.
:param torch.nn.Module model: transformer instance
:param str init_type: initialization type
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/initializer.py
def initialize(model, init_type="pytorch"):
"""Initialize Transformer module.
:param torch.nn.Module model: transformer instance
:param str init_type: initialization type
"""
if init_type == "pytorch":
return
# weight init
for p in model.parameters():
if p.dim() > 1:
if init_type == "xavier_uniform":
torch.nn.init.xavier_uniform_(p.data)
elif init_type == "xavier_normal":
torch.nn.init.xavier_normal_(p.data)
elif init_type == "kaiming_uniform":
torch.nn.init.kaiming_uniform_(p.data, nonlinearity="relu")
elif init_type == "kaiming_normal":
torch.nn.init.kaiming_normal_(p.data, nonlinearity="relu")
else:
raise ValueError("Unknown initialization: " + init_type)
# bias init
for p in model.parameters():
if p.dim() == 1:
p.data.zero_()
# reset some modules with default init
for m in model.modules():
if isinstance(m, (torch.nn.Embedding, LayerNorm)):
m.reset_parameters()
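A short usage sketch (not from the source; the Encoder used here is just a convenient example of a torch.nn.Module): with any init_type other than "pytorch", weight matrices are re-initialized with the chosen scheme, one-dimensional parameters are zeroed, and Embedding/LayerNorm modules are then reset to their defaults.

```python
# Sketch: re-initialize a freshly built module with Xavier-uniform weights.
from tools.espnet_minimal.nets.pytorch_backend.transformer.encoder import Encoder
from tools.espnet_minimal.nets.pytorch_backend.transformer.initializer import initialize

model = Encoder(idim=83)                       # any torch.nn.Module works here
initialize(model, init_type="xavier_uniform")  # weights: Xavier-uniform, biases: zero
```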
label_smoothing_loss
¶
Label smoothing module.
LabelSmoothingLoss (Module)
¶Label-smoothing loss.
:param int size: the number of classes
:param int padding_idx: ignored class id
:param float smoothing: smoothing rate (0.0 means the conventional CE)
:param bool normalize_length: normalize loss by sequence length if True
:param torch.nn.Module criterion: loss function to be smoothed
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/label_smoothing_loss.py
class LabelSmoothingLoss(nn.Module):
"""Label-smoothing loss.
:param int size: the number of class
:param int padding_idx: ignored class id
:param float smoothing: smoothing rate (0.0 means the conventional CE)
:param bool normalize_length: normalize loss by sequence length if True
:param torch.nn.Module criterion: loss function to be smoothed
"""
def __init__(self, size, padding_idx, smoothing, normalize_length=False, criterion=nn.KLDivLoss(reduce=False)):
"""Construct an LabelSmoothingLoss object."""
super(LabelSmoothingLoss, self).__init__()
self.criterion = criterion
self.padding_idx = padding_idx
self.confidence = 1.0 - smoothing
self.smoothing = smoothing
self.size = size
self.true_dist = None
self.normalize_length = normalize_length
def forward(self, x, target):
"""Compute loss between x and target.
:param torch.Tensor x: prediction (batch, seqlen, class)
:param torch.Tensor target: target signal masked with self.padding_id (batch, seqlen)
:return: scalar float value
:rtype torch.Tensor
"""
assert x.size(2) == self.size
batch_size = x.size(0)
x = x.view(-1, self.size)
target = target.view(-1)
with torch.no_grad():
true_dist = x.clone()
true_dist.fill_(self.smoothing / (self.size - 1))
ignore = target == self.padding_idx # (B,)
total = len(target) - ignore.sum().item()
target = target.masked_fill(ignore, 0) # avoid -1 index
true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
denom = total if self.normalize_length else batch_size
return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
__init__(self, size, padding_idx, smoothing, normalize_length=False, criterion=KLDivLoss())
special
¶Construct a LabelSmoothingLoss object.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/label_smoothing_loss.py
def __init__(self, size, padding_idx, smoothing, normalize_length=False, criterion=nn.KLDivLoss(reduce=False)):
"""Construct an LabelSmoothingLoss object."""
super(LabelSmoothingLoss, self).__init__()
self.criterion = criterion
self.padding_idx = padding_idx
self.confidence = 1.0 - smoothing
self.smoothing = smoothing
self.size = size
self.true_dist = None
self.normalize_length = normalize_length
forward(self, x, target)
¶Compute loss between x and target.
:param torch.Tensor x: prediction (batch, seqlen, class)
:param torch.Tensor target: target signal masked with self.padding_idx (batch, seqlen)
:return: scalar float value
:rtype: torch.Tensor
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/label_smoothing_loss.py
def forward(self, x, target):
"""Compute loss between x and target.
:param torch.Tensor x: prediction (batch, seqlen, class)
:param torch.Tensor target: target signal masked with self.padding_id (batch, seqlen)
:return: scalar float value
:rtype torch.Tensor
"""
assert x.size(2) == self.size
batch_size = x.size(0)
x = x.view(-1, self.size)
target = target.view(-1)
with torch.no_grad():
true_dist = x.clone()
true_dist.fill_(self.smoothing / (self.size - 1))
ignore = target == self.padding_idx # (B,)
total = len(target) - ignore.sum().item()
target = target.masked_fill(ignore, 0) # avoid -1 index
true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
denom = total if self.normalize_length else batch_size
return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
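A toy usage sketch (not from the source; the vocabulary size, sequence length, and padding convention are assumptions): with smoothing=0.1 the target distribution puts 0.9 on the gold token and spreads the remaining 0.1 uniformly over the other classes, and positions equal to padding_idx are excluded from the loss.

```python
# Toy sketch for LabelSmoothingLoss (assumed vocabulary size and padding).
import torch
from tools.espnet_minimal.nets.pytorch_backend.transformer.label_smoothing_loss import LabelSmoothingLoss

criterion = LabelSmoothingLoss(size=5000, padding_idx=-1, smoothing=0.1)
logits = torch.randn(2, 7, 5000)           # (batch, seqlen, vocab)
target = torch.randint(0, 5000, (2, 7))    # gold token ids
target[:, -2:] = -1                        # padded positions are ignored
loss = criterion(logits, target)           # scalar tensor
```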
layer_norm
¶
Layer normalization module.
LayerNorm (LayerNorm)
¶Layer normalization module.
:param int nout: output dim size
:param int dim: dimension to be normalized
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/layer_norm.py
class LayerNorm(torch.nn.LayerNorm):
"""Layer normalization module.
:param int nout: output dim size
:param int dim: dimension to be normalized
"""
def __init__(self, nout, dim=-1):
"""Construct an LayerNorm object."""
super(LayerNorm, self).__init__(nout, eps=1e-12)
self.dim = dim
def forward(self, x):
"""Apply layer normalization.
:param torch.Tensor x: input tensor
:return: layer normalized tensor
:rtype torch.Tensor
"""
if self.dim == -1:
return super(LayerNorm, self).forward(x)
return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1)
mask
¶
Mask module.
datatype
¶is_torch_1_2_plus
¶subsequent_mask(size, device='cpu', dtype=torch.bool)
¶Create mask for subsequent steps (1, size, size).
:param int size: size of mask
:param str device: "cpu" or "cuda" or torch.Tensor.device
:param torch.dtype dtype: result dtype
:rtype: torch.Tensor

>>> subsequent_mask(3)
[[1, 0, 0],
 [1, 1, 0],
 [1, 1, 1]]
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/mask.py
def subsequent_mask(size, device="cpu", dtype=datatype):
"""Create mask for subsequent steps (1, size, size).
:param int size: size of mask
:param str device: "cpu" or "cuda" or torch.Tensor.device
:param torch.dtype dtype: result dtype
:rtype: torch.Tensor
>>> subsequent_mask(3)
[[1, 0, 0],
[1, 1, 0],
[1, 1, 1]]
"""
ret = torch.ones(size, size, device=device, dtype=dtype)
return torch.tril(ret, out=ret)
target_mask(ys_in_pad, ignore_id)
¶Create mask for decoder self-attention.
:param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
:param int ignore_id: index of padding
:param torch.dtype dtype: result dtype
:rtype: torch.Tensor
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/mask.py
def target_mask(ys_in_pad, ignore_id):
"""Create mask for decoder self-attention.
:param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
:param int ignore_id: index of padding
:param torch.dtype dtype: result dtype
:rtype: torch.Tensor
"""
ys_mask = ys_in_pad != ignore_id
m = subsequent_mask(ys_mask.size(-1), device=ys_mask.device).unsqueeze(0)
return ys_mask.unsqueeze(-2) & m
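A small sketch of how the two helpers combine (not from the source; the padding value -1 is an assumption): as written, subsequent_mask returns a (size, size) lower-triangular mask, which target_mask broadcasts against the padding mask of the target batch.

```python
# Sketch: causal mask plus padding mask for decoder self-attention.
import torch
from tools.espnet_minimal.nets.pytorch_backend.transformer.mask import subsequent_mask, target_mask

print(subsequent_mask(3))                 # (3, 3) lower-triangular mask, as in the docstring example
ys_in_pad = torch.tensor([[7, 8, 9],
                          [7, 8, -1]])    # -1 marks padding
m = target_mask(ys_in_pad, ignore_id=-1)  # boolean mask broadcast to (batch, Lmax, Lmax)
```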
multi_layer_conv
¶
Layer modules for FFT block in FastSpeech (Feed-forward Transformer).
Conv1dLinear (Module)
¶Conv1D + Linear for Transformer block.
A variant of MultiLayeredConv1d that replaces the second conv layer with a linear layer.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/multi_layer_conv.py
class Conv1dLinear(torch.nn.Module):
"""Conv1D + Linear for Transformer block.
A variant of MultiLayeredConv1d, which replaces second conv-layer to linear.
"""
def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
"""Initialize Conv1dLinear module.
Args:
in_chans (int): Number of input channels.
hidden_chans (int): Number of hidden channels.
kernel_size (int): Kernel size of conv1d.
dropout_rate (float): Dropout rate.
"""
super(Conv1dLinear, self).__init__()
self.w_1 = torch.nn.Conv1d(in_chans, hidden_chans, kernel_size,
stride=1, padding=(kernel_size - 1) // 2)
self.w_2 = torch.nn.Linear(hidden_chans, in_chans)
self.dropout = torch.nn.Dropout(dropout_rate)
def forward(self, x):
"""Calculate forward propagation.
Args:
x (Tensor): Batch of input tensors (B, ..., in_chans).
Returns:
Tensor: Batch of output tensors (B, ..., hidden_chans).
"""
x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
return self.w_2(self.dropout(x))
__init__(self, in_chans, hidden_chans, kernel_size, dropout_rate)
special
¶Initialize Conv1dLinear module.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| in_chans | int | Number of input channels. | required |
| hidden_chans | int | Number of hidden channels. | required |
| kernel_size | int | Kernel size of conv1d. | required |
| dropout_rate | float | Dropout rate. | required |
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/multi_layer_conv.py
def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
"""Initialize Conv1dLinear module.
Args:
in_chans (int): Number of input channels.
hidden_chans (int): Number of hidden channels.
kernel_size (int): Kernel size of conv1d.
dropout_rate (float): Dropout rate.
"""
super(Conv1dLinear, self).__init__()
self.w_1 = torch.nn.Conv1d(in_chans, hidden_chans, kernel_size,
stride=1, padding=(kernel_size - 1) // 2)
self.w_2 = torch.nn.Linear(hidden_chans, in_chans)
self.dropout = torch.nn.Dropout(dropout_rate)
forward(self, x)
¶Calculate forward propagation.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Batch of input tensors (B, ..., in_chans). | required |

Returns:

| Type | Description |
|---|---|
| Tensor | Batch of output tensors (B, ..., in_chans). |
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/multi_layer_conv.py
MultiLayeredConv1d (Module)
¶Multi-layered conv1d for Transformer block.
This is a module of multi-layered conv1d designed to replace the position-wise feed-forward network in the Transformer block, as introduced in FastSpeech: Fast, Robust and Controllable Text to Speech (https://arxiv.org/pdf/1905.09263.pdf).
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/multi_layer_conv.py
class MultiLayeredConv1d(torch.nn.Module):
"""Multi-layered conv1d for Transformer block.
This is a module of multi-leyered conv1d designed to replace positionwise feed-forward network
in Transforner block, which is introduced in `FastSpeech: Fast, Robust and Controllable Text to Speech`_.
.. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
https://arxiv.org/pdf/1905.09263.pdf
"""
def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
"""Initialize MultiLayeredConv1d module.
Args:
in_chans (int): Number of input channels.
hidden_chans (int): Number of hidden channels.
kernel_size (int): Kernel size of conv1d.
dropout_rate (float): Dropout rate.
"""
super(MultiLayeredConv1d, self).__init__()
self.w_1 = torch.nn.Conv1d(in_chans, hidden_chans, kernel_size,
stride=1, padding=(kernel_size - 1) // 2)
self.w_2 = torch.nn.Conv1d(hidden_chans, in_chans, kernel_size,
stride=1, padding=(kernel_size - 1) // 2)
self.dropout = torch.nn.Dropout(dropout_rate)
def forward(self, x):
"""Calculate forward propagation.
Args:
x (Tensor): Batch of input tensors (B, ..., in_chans).
Returns:
Tensor: Batch of output tensors (B, ..., hidden_chans).
"""
x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1)
__init__(self, in_chans, hidden_chans, kernel_size, dropout_rate)
special
¶Initialize MultiLayeredConv1d module.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| in_chans | int | Number of input channels. | required |
| hidden_chans | int | Number of hidden channels. | required |
| kernel_size | int | Kernel size of conv1d. | required |
| dropout_rate | float | Dropout rate. | required |
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/multi_layer_conv.py
def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
"""Initialize MultiLayeredConv1d module.
Args:
in_chans (int): Number of input channels.
hidden_chans (int): Number of hidden channels.
kernel_size (int): Kernel size of conv1d.
dropout_rate (float): Dropout rate.
"""
super(MultiLayeredConv1d, self).__init__()
self.w_1 = torch.nn.Conv1d(in_chans, hidden_chans, kernel_size,
stride=1, padding=(kernel_size - 1) // 2)
self.w_2 = torch.nn.Conv1d(hidden_chans, in_chans, kernel_size,
stride=1, padding=(kernel_size - 1) // 2)
self.dropout = torch.nn.Dropout(dropout_rate)
forward(self, x)
¶Calculate forward propagation.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Batch of input tensors (B, ..., in_chans). | required |

Returns:

| Type | Description |
|---|---|
| Tensor | Batch of output tensors (B, ..., in_chans). |
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/multi_layer_conv.py
def forward(self, x):
"""Calculate forward propagation.
Args:
x (Tensor): Batch of input tensors (B, ..., in_chans).
Returns:
Tensor: Batch of output tensors (B, ..., hidden_chans).
"""
x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1)
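A brief usage sketch (not from the source; the channel sizes are illustrative): the block operates on (batch, time, channels) tensors, transposing internally to channels-first for Conv1d, and the second layer maps back to in_chans, so the output keeps the input width.

```python
# Sketch: FastSpeech-style position-wise block on (B, T, C) inputs.
import torch
from tools.espnet_minimal.nets.pytorch_backend.transformer.multi_layer_conv import MultiLayeredConv1d

ff = MultiLayeredConv1d(in_chans=256, hidden_chans=1024, kernel_size=3, dropout_rate=0.1)
x = torch.randn(2, 50, 256)    # (batch, time, in_chans)
y = ff(x)                      # second conv maps hidden_chans back to in_chans
assert y.shape == (2, 50, 256)
```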
optimizer
¶
Optimizer module.
NoamOpt
¶Optim wrapper that implements rate.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/optimizer.py
class NoamOpt(object):
"""Optim wrapper that implements rate."""
def __init__(self, model_size, factor, warmup, optimizer):
"""Construct an NoamOpt object."""
self.optimizer = optimizer
self._step = 0
self.warmup = warmup
self.factor = factor
self.model_size = model_size
self._rate = 0
@property
def param_groups(self):
"""Return param_groups."""
return self.optimizer.param_groups
def step(self):
"""Update parameters and rate."""
self._step += 1
rate = self.rate()
for p in self.optimizer.param_groups:
p['lr'] = rate
self._rate = rate
self.optimizer.step()
def rate(self, step=None):
"""Implement `lrate` above."""
if step is None:
step = self._step
return self.factor * self.model_size ** (-0.5) \
* min(step ** (-0.5), step * self.warmup ** (-1.5))
def zero_grad(self):
"""Reset gradient."""
self.optimizer.zero_grad()
def state_dict(self):
"""Return state_dict."""
return {
"_step": self._step,
"warmup": self.warmup,
"factor": self.factor,
"model_size": self.model_size,
"_rate": self._rate,
"optimizer": self.optimizer.state_dict()
}
def load_state_dict(self, state_dict):
"""Load state_dict."""
for key, value in state_dict.items():
if key == "optimizer":
self.optimizer.load_state_dict(state_dict["optimizer"])
else:
setattr(self, key, value)
param_groups
property
readonly
¶Return param_groups.
__init__(self, model_size, factor, warmup, optimizer)
special
¶Construct a NoamOpt object.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/optimizer.py
load_state_dict(self, state_dict)
¶Load state_dict.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/optimizer.py
rate(self, step=None)
¶Implement `lrate` above.
state_dict(self)
¶Return state_dict.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/optimizer.py
step(self)
¶zero_grad(self)
¶get_std_opt(model, d_model, warmup, factor)
¶Get standard NoamOpt.
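The schedule implemented by rate() is the standard "Noam" warm-up: rate(step) = factor * model_size^(-0.5) * min(step^(-0.5), step * warmup^(-1.5)), i.e. linear warm-up followed by inverse-square-root decay. A usage sketch (not from the source; the Adam hyperparameters are the values commonly paired with this schedule, and get_std_opt presumably builds an equivalent wrapper from a model and d_model):

```python
# Sketch: wrap Adam in the Noam learning-rate schedule (assumed hyperparameters).
import torch
from tools.espnet_minimal.nets.pytorch_backend.transformer.optimizer import NoamOpt

model = torch.nn.Linear(256, 256)
base = torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
opt = NoamOpt(model_size=256, factor=1.0, warmup=25000, optimizer=base)

loss = model(torch.randn(8, 256)).pow(2).mean()
loss.backward()
opt.step()        # sets lr = factor * 256**-0.5 * min(step**-0.5, step * warmup**-1.5), then Adam step
opt.zero_grad()
```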
plot
¶
PlotAttentionReport
¶Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/plot.py
class PlotAttentionReport():
def plotfn(self, *args, **kwargs):
plot_multi_head_attention(*args, **kwargs)
def __call__(self, trainer):
attn_dict = self.get_attention_weights()
suffix = "ep.{.updater.epoch}.png".format(trainer)
self.plotfn(self.data, attn_dict, self.outdir, suffix, savefig)
def get_attention_weights(self):
batch = self.converter([self.transform(self.data)], self.device)
if isinstance(batch, tuple):
att_ws = self.att_vis_fn(*batch)
elif isinstance(batch, dict):
att_ws = self.att_vis_fn(**batch)
return att_ws
def log_attentions(self, logger, step):
def log_fig(plot, filename):
from os.path import basename
logger.add_figure(basename(filename), plot, step)
plt.clf()
attn_dict = self.get_attention_weights()
self.plotfn(self.data, attn_dict, self.outdir, "", log_fig)
__call__(self, trainer)
special
¶get_attention_weights(self)
¶Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/plot.py
log_attentions(self, logger, step)
¶Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/plot.py
plotfn(self, *args, **kwargs)
¶plot_multi_head_attention(data, attn_dict, outdir, suffix='png', savefn=savefig)
¶Plot multi head attentions
:param dict data: utts info from json file
:param dict[str, torch.Tensor] attn_dict: multi head attention dict. values should be torch.Tensor (head, input_length, output_length)
:param str outdir: dir to save fig
:param str suffix: filename suffix including image type (e.g., png)
:param savefn: function to save
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/plot.py
def plot_multi_head_attention(data, attn_dict, outdir, suffix="png", savefn=savefig):
"""Plot multi head attentions
:param dict data: utts info from json file
:param dict[str, torch.Tensor] attn_dict: multi head attention dict.
values should be torch.Tensor (head, input_length, output_length)
:param str outdir: dir to save fig
:param str suffix: filename suffix including image type (e.g., png)
:param savefn: function to save
"""
for name, att_ws in attn_dict.items():
for idx, att_w in enumerate(att_ws):
filename = "%s/%s.%s.%s" % (
outdir, data[idx][0], name, suffix)
dec_len = int(data[idx][1]['output'][0]['shape'][0])
enc_len = int(data[idx][1]['input'][0]['shape'][0])
if "encoder" in name:
att_w = att_w[:, :enc_len, :enc_len]
elif "decoder" in name:
if "self" in name:
att_w = att_w[:, :dec_len, :dec_len]
else:
att_w = att_w[:, :dec_len, :enc_len]
else:
logging.warning("unknown name for shaping attention")
fig = _plot_and_save_attention(att_w, filename)
savefn(fig, filename)
savefig(plot, filename)
¶
positionwise_feed_forward
¶
Positionwise feed forward layer definition.
PositionwiseFeedForward (Module)
¶Positionwise feed forward layer.
:param int idim: input dimension
:param int hidden_units: number of hidden units
:param float dropout_rate: dropout rate
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/positionwise_feed_forward.py
class PositionwiseFeedForward(torch.nn.Module):
"""Positionwise feed forward layer.
:param int idim: input dimenstion
:param int hidden_units: number of hidden units
:param float dropout_rate: dropout rate
"""
def __init__(self, idim, hidden_units, dropout_rate):
"""Construct an PositionwiseFeedForward object."""
super(PositionwiseFeedForward, self).__init__()
self.w_1 = torch.nn.Linear(idim, hidden_units)
self.w_2 = torch.nn.Linear(hidden_units, idim)
self.dropout = torch.nn.Dropout(dropout_rate)
def forward(self, x):
"""Forward funciton."""
return self.w_2(self.dropout(torch.relu(self.w_1(x))))
repeat
¶
Repeat the same layer definition.
MultiSequential (Sequential)
¶Multi-input multi-output torch.nn.Sequential.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/repeat.py
forward(self, *args)
¶repeat(N, fn)
¶Repeat module N times.
:param int N: repeat time
:param function fn: function to generate module
:return: repeated modules
:rtype: MultiSequential
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/repeat.py
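The source is not reproduced above; judging from the signature, the return type, and the way Encoder calls repeat(num_blocks, lambda: EncoderLayer(...)), the helper most likely instantiates the zero-argument factory N times and chains the modules in a MultiSequential whose forward threads a tuple of arguments (here (xs, masks)) through every layer. A rough sketch of that assumed behaviour:

```python
# Assumed behaviour of MultiSequential / repeat (sketch, not the library source).
import torch


class MultiSequential(torch.nn.Sequential):
    """Sequential container whose forward passes a tuple of arguments through each layer."""

    def forward(self, *args):
        for module in self:
            args = module(*args)
        return args


def repeat(N, fn):
    """Instantiate fn() N times and chain the resulting modules."""
    return MultiSequential(*[fn() for _ in range(N)])
```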
subsampling
¶
Subsampling layer definition.
Conv2dSubsampling (Module)
¶Convolutional 2D subsampling (to 1/4 length).
:param int idim: input dim
:param int odim: output dim
:param float dropout_rate: dropout rate
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/subsampling.py
class Conv2dSubsampling(torch.nn.Module):
"""Convolutional 2D subsampling (to 1/4 length).
:param int idim: input dim
:param int odim: output dim
:param flaot dropout_rate: dropout rate
"""
def __init__(self, idim, odim, dropout_rate):
"""Construct an Conv2dSubsampling object."""
super(Conv2dSubsampling, self).__init__()
self.conv = torch.nn.Sequential(
torch.nn.Conv2d(1, odim, 3, 2),
torch.nn.ReLU(),
torch.nn.Conv2d(odim, odim, 3, 2),
torch.nn.ReLU()
)
self.out = torch.nn.Sequential(
torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim),
PositionalEncoding(odim, dropout_rate)
)
def forward(self, x, x_mask):
"""Subsample x.
:param torch.Tensor x: input tensor
:param torch.Tensor x_mask: input mask
:return: subsampled x and mask
:rtype Tuple[torch.Tensor, torch.Tensor]
"""
x = x.unsqueeze(1) # (b, c, t, f)
x = self.conv(x)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
if x_mask is None:
return x, None
return x, x_mask[:, :, :-2:2][:, :, :-2:2]
__init__(self, idim, odim, dropout_rate)
special
¶Construct a Conv2dSubsampling object.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/subsampling.py
def __init__(self, idim, odim, dropout_rate):
"""Construct an Conv2dSubsampling object."""
super(Conv2dSubsampling, self).__init__()
self.conv = torch.nn.Sequential(
torch.nn.Conv2d(1, odim, 3, 2),
torch.nn.ReLU(),
torch.nn.Conv2d(odim, odim, 3, 2),
torch.nn.ReLU()
)
self.out = torch.nn.Sequential(
torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim),
PositionalEncoding(odim, dropout_rate)
)
forward(self, x, x_mask)
¶Subsample x.
:param torch.Tensor x: input tensor
:param torch.Tensor x_mask: input mask
:return: subsampled x and mask
:rtype: Tuple[torch.Tensor, torch.Tensor]
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/transformer/subsampling.py
def forward(self, x, x_mask):
"""Subsample x.
:param torch.Tensor x: input tensor
:param torch.Tensor x_mask: input mask
:return: subsampled x and mask
:rtype Tuple[torch.Tensor, torch.Tensor]
"""
x = x.unsqueeze(1) # (b, c, t, f)
x = self.conv(x)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
if x_mask is None:
return x, None
return x, x_mask[:, :, :-2:2][:, :, :-2:2]
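A shape sketch (not from the source; the concrete sizes are illustrative): the two stride-2, kernel-3 convolutions reduce the time axis to roughly a quarter (t -> ((t - 1) // 2 - 1) // 2), the remaining frequency bins are flattened into the linear projection to odim, and the mask is subsampled with the same stride pattern.

```python
# Sketch: shape bookkeeping for Conv2dSubsampling (illustrative sizes).
import torch
from tools.espnet_minimal.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling

sub = Conv2dSubsampling(idim=83, odim=256, dropout_rate=0.1)
x = torch.randn(4, 200, 83)                     # (batch, time, idim)
mask = torch.ones(4, 1, 200, dtype=torch.bool)
y, y_mask = sub(x, mask)                        # y: (4, 49, 256), y_mask: (4, 1, 49)
```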
wavenet
¶
This code is based on https://github.com/kan-bayashi/PytorchWaveNetVocoder.
CausalConv1d (Module)
¶
1D dilated causal convolution.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/wavenet.py
class CausalConv1d(nn.Module):
"""1D dilated causal convolution."""
def __init__(self, in_channels, out_channels, kernel_size, dilation=1, bias=True):
super(CausalConv1d, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.dilation = dilation
self.padding = padding = (kernel_size - 1) * dilation
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
padding=padding, dilation=dilation, bias=bias)
def forward(self, x):
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor with the shape (B, in_channels, T).
Returns:
Tensor: Tensor with the shape (B, out_channels, T)
"""
x = self.conv(x)
if self.padding != 0:
x = x[:, :, :-self.padding]
return x
__init__(self, in_channels, out_channels, kernel_size, dilation=1, bias=True)
special
¶Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/wavenet.py
def __init__(self, in_channels, out_channels, kernel_size, dilation=1, bias=True):
super(CausalConv1d, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.dilation = dilation
self.padding = padding = (kernel_size - 1) * dilation
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
padding=padding, dilation=dilation, bias=bias)
forward(self, x)
¶Calculate forward propagation.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor with the shape (B, in_channels, T). | required |

Returns:

| Type | Description |
|---|---|
| Tensor | Tensor with the shape (B, out_channels, T). |
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/wavenet.py
OneHot (Module)
¶
Convert to one-hot vector.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| depth | int | Dimension of one-hot vector. | required |
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/wavenet.py
class OneHot(nn.Module):
"""Convert to one-hot vector.
Args:
depth (int): Dimension of one-hot vector.
"""
def __init__(self, depth):
super(OneHot, self).__init__()
self.depth = depth
def forward(self, x):
"""Calculate forward propagation.
Args:
x (LongTensor): long tensor variable with the shape (B, T)
Returns:
Tensor: float tensor variable with the shape (B, depth, T)
"""
x = x % self.depth
x = torch.unsqueeze(x, 2)
x_onehot = x.new_zeros(x.size(0), x.size(1), self.depth).float()
return x_onehot.scatter_(2, x, 1)
__init__(self, depth)
special
¶forward(self, x)
¶Calculate forward propagation.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | LongTensor | Long tensor variable with the shape (B, T). | required |

Returns:

| Type | Description |
|---|---|
| Tensor | Float tensor variable with the shape (B, depth, T). |
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/wavenet.py
def forward(self, x):
"""Calculate forward propagation.
Args:
x (LongTensor): long tensor variable with the shape (B, T)
Returns:
Tensor: float tensor variable with the shape (B, depth, T)
"""
x = x % self.depth
x = torch.unsqueeze(x, 2)
x_onehot = x.new_zeros(x.size(0), x.size(1), self.depth).float()
return x_onehot.scatter_(2, x, 1)
UpSampling (Module)
¶
Upsampling layer with deconvolution.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| upsampling_factor | int | Upsampling factor. | required |
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/wavenet.py
class UpSampling(nn.Module):
"""Upsampling layer with deconvolution.
Args:
upsampling_factor (int): Upsampling factor.
"""
def __init__(self, upsampling_factor, bias=True):
super(UpSampling, self).__init__()
self.upsampling_factor = upsampling_factor
self.bias = bias
self.conv = nn.ConvTranspose2d(1, 1,
kernel_size=(1, self.upsampling_factor),
stride=(1, self.upsampling_factor),
bias=self.bias)
def forward(self, x):
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor with the shape (B, C, T)
Returns:
Tensor: Tensor with the shape (B, C, T') where T' = T * upsampling_factor.
"""
x = x.unsqueeze(1) # B x 1 x C x T
x = self.conv(x) # B x 1 x C x T'
return x.squeeze(1)
__init__(self, upsampling_factor, bias=True)
special
¶Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/wavenet.py
forward(self, x)
¶Calculate forward propagation.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor with the shape (B, C, T). | required |

Returns:

| Type | Description |
|---|---|
| Tensor | Tensor with the shape (B, C, T') where T' = T * upsampling_factor. |
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/wavenet.py
WaveNet (Module)
¶
Conditional wavenet.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| n_quantize | int | Number of quantization. | 256 |
| n_aux | int | Number of aux feature dimension. | 28 |
| n_resch | int | Number of filter channels for residual block. | 512 |
| n_skipch | int | Number of filter channels for skip connection. | 256 |
| dilation_depth | int | Number of dilation depth (e.g. if set 10, max dilation = 2^(10-1)). | 10 |
| dilation_repeat | int | Number of dilation repeat. | 3 |
| kernel_size | int | Filter size of dilated causal convolution. | 2 |
| upsampling_factor | int | Upsampling factor. | 0 |
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/wavenet.py
class WaveNet(nn.Module):
"""Conditional wavenet.
Args:
n_quantize (int): Number of quantization.
n_aux (int): Number of aux feature dimension.
n_resch (int): Number of filter channels for residual block.
n_skipch (int): Number of filter channels for skip connection.
dilation_depth (int): Number of dilation depth (e.g. if set 10, max dilation = 2^(10-1)).
dilation_repeat (int): Number of dilation repeat.
kernel_size (int): Filter size of dilated causal convolution.
upsampling_factor (int): Upsampling factor.
"""
def __init__(self, n_quantize=256, n_aux=28, n_resch=512, n_skipch=256,
dilation_depth=10, dilation_repeat=3, kernel_size=2, upsampling_factor=0):
super(WaveNet, self).__init__()
self.n_aux = n_aux
self.n_quantize = n_quantize
self.n_resch = n_resch
self.n_skipch = n_skipch
self.kernel_size = kernel_size
self.dilation_depth = dilation_depth
self.dilation_repeat = dilation_repeat
self.upsampling_factor = upsampling_factor
self.dilations = [2 ** i for i in range(self.dilation_depth)] * self.dilation_repeat
self.receptive_field = (self.kernel_size - 1) * sum(self.dilations) + 1
# for preprocessing
self.onehot = OneHot(self.n_quantize)
self.causal = CausalConv1d(self.n_quantize, self.n_resch, self.kernel_size)
if self.upsampling_factor > 0:
self.upsampling = UpSampling(self.upsampling_factor)
# for residual blocks
self.dil_sigmoid = nn.ModuleList()
self.dil_tanh = nn.ModuleList()
self.aux_1x1_sigmoid = nn.ModuleList()
self.aux_1x1_tanh = nn.ModuleList()
self.skip_1x1 = nn.ModuleList()
self.res_1x1 = nn.ModuleList()
for d in self.dilations:
self.dil_sigmoid += [CausalConv1d(self.n_resch, self.n_resch, self.kernel_size, d)]
self.dil_tanh += [CausalConv1d(self.n_resch, self.n_resch, self.kernel_size, d)]
self.aux_1x1_sigmoid += [nn.Conv1d(self.n_aux, self.n_resch, 1)]
self.aux_1x1_tanh += [nn.Conv1d(self.n_aux, self.n_resch, 1)]
self.skip_1x1 += [nn.Conv1d(self.n_resch, self.n_skipch, 1)]
self.res_1x1 += [nn.Conv1d(self.n_resch, self.n_resch, 1)]
# for postprocessing
self.conv_post_1 = nn.Conv1d(self.n_skipch, self.n_skipch, 1)
self.conv_post_2 = nn.Conv1d(self.n_skipch, self.n_quantize, 1)
def forward(self, x, h):
"""Calculate forward propagation.
Args:
x (LongTensor): Quantized input waveform tensor with the shape (B, T).
h (Tensor): Auxiliary feature tensor with the shape (B, n_aux, T).
Returns:
Tensor: Logits with the shape (B, T, n_quantize).
"""
# preprocess
output = self._preprocess(x)
if self.upsampling_factor > 0:
h = self.upsampling(h)
# residual block
skip_connections = []
for l in range(len(self.dilations)):
output, skip = self._residual_forward(
output, h, self.dil_sigmoid[l], self.dil_tanh[l],
self.aux_1x1_sigmoid[l], self.aux_1x1_tanh[l],
self.skip_1x1[l], self.res_1x1[l])
skip_connections.append(skip)
# skip-connection part
output = sum(skip_connections)
output = self._postprocess(output)
return output
def generate(self, x, h, n_samples, interval=None, mode="sampling"):
"""Generate a waveform with fast genration algorithm.
This generation based on `Fast WaveNet Generation Algorithm`_.
Args:
x (LongTensor): Initial waveform tensor with the shape (T,).
h (Tensor): Auxiliary feature tensor with the shape (n_samples + T, n_aux).
n_samples (int): Number of samples to be generated.
interval (int, optional): Log interval.
mode (str, optional): "sampling" or "argmax".
Return:
ndarray: Generated quantized waveform (n_samples).
.. _`Fast WaveNet Generation Algorithm`: https://arxiv.org/abs/1611.09482
"""
# reshape inputs
assert len(x.shape) == 1
assert len(h.shape) == 2 and h.shape[1] == self.n_aux
x = x.unsqueeze(0)
h = h.transpose(0, 1).unsqueeze(0)
# perform upsampling
if self.upsampling_factor > 0:
h = self.upsampling(h)
# padding for shortage
if n_samples > h.shape[2]:
h = F.pad(h, (0, n_samples - h.shape[2]), "replicate")
# padding if the length less than
n_pad = self.receptive_field - x.size(1)
if n_pad > 0:
x = F.pad(x, (n_pad, 0), "constant", self.n_quantize // 2)
h = F.pad(h, (n_pad, 0), "replicate")
# prepare buffer
output = self._preprocess(x)
h_ = h[:, :, :x.size(1)]
output_buffer = []
buffer_size = []
for l, d in enumerate(self.dilations):
output, _ = self._residual_forward(
output, h_, self.dil_sigmoid[l], self.dil_tanh[l],
self.aux_1x1_sigmoid[l], self.aux_1x1_tanh[l],
self.skip_1x1[l], self.res_1x1[l])
if d == 2 ** (self.dilation_depth - 1):
buffer_size.append(self.kernel_size - 1)
else:
buffer_size.append(d * 2 * (self.kernel_size - 1))
output_buffer.append(output[:, :, -buffer_size[l] - 1: -1])
# generate
samples = x[0]
start_time = time.time()
for i in range(n_samples):
output = samples[-self.kernel_size * 2 + 1:].unsqueeze(0)
output = self._preprocess(output)
h_ = h[:, :, samples.size(0) - 1].contiguous().view(1, self.n_aux, 1)
output_buffer_next = []
skip_connections = []
for l, d in enumerate(self.dilations):
output, skip = self._generate_residual_forward(
output, h_, self.dil_sigmoid[l], self.dil_tanh[l],
self.aux_1x1_sigmoid[l], self.aux_1x1_tanh[l],
self.skip_1x1[l], self.res_1x1[l])
output = torch.cat([output_buffer[l], output], dim=2)
output_buffer_next.append(output[:, :, -buffer_size[l]:])
skip_connections.append(skip)
# update buffer
output_buffer = output_buffer_next
# get predicted sample
output = sum(skip_connections)
output = self._postprocess(output)[0]
if mode == "sampling":
posterior = F.softmax(output[-1], dim=0)
dist = torch.distributions.Categorical(posterior)
sample = dist.sample().unsqueeze(0)
elif mode == "argmax":
sample = output.argmax(-1)
else:
logging.error("mode should be sampling or argmax")
sys.exit(1)
samples = torch.cat([samples, sample], dim=0)
# show progress
if interval is not None and (i + 1) % interval == 0:
elapsed_time_per_sample = (time.time() - start_time) / interval
logging.info("%d/%d estimated time = %.3f sec (%.3f sec / sample)" % (
i + 1, n_samples, (n_samples - i - 1) * elapsed_time_per_sample, elapsed_time_per_sample))
start_time = time.time()
return samples[-n_samples:].cpu().numpy()
def _preprocess(self, x):
x = self.onehot(x).transpose(1, 2)
output = self.causal(x)
return output
def _postprocess(self, x):
output = F.relu(x)
output = self.conv_post_1(output)
output = F.relu(output) # B x C x T
output = self.conv_post_2(output).transpose(1, 2) # B x T x C
return output
def _residual_forward(self, x, h, dil_sigmoid, dil_tanh,
aux_1x1_sigmoid, aux_1x1_tanh, skip_1x1, res_1x1):
output_sigmoid = dil_sigmoid(x)
output_tanh = dil_tanh(x)
aux_output_sigmoid = aux_1x1_sigmoid(h)
aux_output_tanh = aux_1x1_tanh(h)
output = torch.sigmoid(output_sigmoid + aux_output_sigmoid) * \
torch.tanh(output_tanh + aux_output_tanh)
skip = skip_1x1(output)
output = res_1x1(output)
output = output + x
return output, skip
def _generate_residual_forward(self, x, h, dil_sigmoid, dil_tanh,
aux_1x1_sigmoid, aux_1x1_tanh, skip_1x1, res_1x1):
output_sigmoid = dil_sigmoid(x)[:, :, -1:]
output_tanh = dil_tanh(x)[:, :, -1:]
aux_output_sigmoid = aux_1x1_sigmoid(h)
aux_output_tanh = aux_1x1_tanh(h)
output = torch.sigmoid(output_sigmoid + aux_output_sigmoid) * \
torch.tanh(output_tanh + aux_output_tanh)
skip = skip_1x1(output)
output = res_1x1(output)
output = output + x[:, :, -1:] # B x C x 1
return output, skip
__init__(self, n_quantize=256, n_aux=28, n_resch=512, n_skipch=256, dilation_depth=10, dilation_repeat=3, kernel_size=2, upsampling_factor=0)
special
¶Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/wavenet.py
def __init__(self, n_quantize=256, n_aux=28, n_resch=512, n_skipch=256,
dilation_depth=10, dilation_repeat=3, kernel_size=2, upsampling_factor=0):
super(WaveNet, self).__init__()
self.n_aux = n_aux
self.n_quantize = n_quantize
self.n_resch = n_resch
self.n_skipch = n_skipch
self.kernel_size = kernel_size
self.dilation_depth = dilation_depth
self.dilation_repeat = dilation_repeat
self.upsampling_factor = upsampling_factor
self.dilations = [2 ** i for i in range(self.dilation_depth)] * self.dilation_repeat
self.receptive_field = (self.kernel_size - 1) * sum(self.dilations) + 1
# for preprocessing
self.onehot = OneHot(self.n_quantize)
self.causal = CausalConv1d(self.n_quantize, self.n_resch, self.kernel_size)
if self.upsampling_factor > 0:
self.upsampling = UpSampling(self.upsampling_factor)
# for residual blocks
self.dil_sigmoid = nn.ModuleList()
self.dil_tanh = nn.ModuleList()
self.aux_1x1_sigmoid = nn.ModuleList()
self.aux_1x1_tanh = nn.ModuleList()
self.skip_1x1 = nn.ModuleList()
self.res_1x1 = nn.ModuleList()
for d in self.dilations:
self.dil_sigmoid += [CausalConv1d(self.n_resch, self.n_resch, self.kernel_size, d)]
self.dil_tanh += [CausalConv1d(self.n_resch, self.n_resch, self.kernel_size, d)]
self.aux_1x1_sigmoid += [nn.Conv1d(self.n_aux, self.n_resch, 1)]
self.aux_1x1_tanh += [nn.Conv1d(self.n_aux, self.n_resch, 1)]
self.skip_1x1 += [nn.Conv1d(self.n_resch, self.n_skipch, 1)]
self.res_1x1 += [nn.Conv1d(self.n_resch, self.n_resch, 1)]
# for postprocessing
self.conv_post_1 = nn.Conv1d(self.n_skipch, self.n_skipch, 1)
self.conv_post_2 = nn.Conv1d(self.n_skipch, self.n_quantize, 1)
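With the defaults above (kernel_size=2, dilation_depth=10, dilation_repeat=3), the dilation list is [1, 2, 4, ..., 512] repeated three times, so receptive_field = (2 - 1) * 3 * (2^10 - 1) + 1 = 3070 samples. A quick check of that arithmetic:

```python
# Quick check of the receptive-field arithmetic for the default WaveNet settings.
kernel_size, dilation_depth, dilation_repeat = 2, 10, 3
dilations = [2 ** i for i in range(dilation_depth)] * dilation_repeat
receptive_field = (kernel_size - 1) * sum(dilations) + 1
print(receptive_field)  # 3070 == 3 * (2**10 - 1) + 1
```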
forward(self, x, h)
¶Calculate forward propagation.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | LongTensor | Quantized input waveform tensor with the shape (B, T). | required |
| h | Tensor | Auxiliary feature tensor with the shape (B, n_aux, T). | required |

Returns:

| Type | Description |
|---|---|
| Tensor | Logits with the shape (B, T, n_quantize). |
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/wavenet.py
def forward(self, x, h):
"""Calculate forward propagation.
Args:
x (LongTensor): Quantized input waveform tensor with the shape (B, T).
h (Tensor): Auxiliary feature tensor with the shape (B, n_aux, T).
Returns:
Tensor: Logits with the shape (B, T, n_quantize).
"""
# preprocess
output = self._preprocess(x)
if self.upsampling_factor > 0:
h = self.upsampling(h)
# residual block
skip_connections = []
for l in range(len(self.dilations)):
output, skip = self._residual_forward(
output, h, self.dil_sigmoid[l], self.dil_tanh[l],
self.aux_1x1_sigmoid[l], self.aux_1x1_tanh[l],
self.skip_1x1[l], self.res_1x1[l])
skip_connections.append(skip)
# skip-connection part
output = sum(skip_connections)
output = self._postprocess(output)
return output
generate(self, x, h, n_samples, interval=None, mode='sampling')
¶Generate a waveform with the fast generation algorithm.
This generation is based on the Fast WaveNet Generation Algorithm.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | LongTensor | Initial waveform tensor with the shape (T,). | required |
| h | Tensor | Auxiliary feature tensor with the shape (n_samples + T, n_aux). | required |
| n_samples | int | Number of samples to be generated. | required |
| interval | int | Log interval. | None |
| mode | str | "sampling" or "argmax". | 'sampling' |

Returns:

| Type | Description |
|---|---|
| ndarray | Generated quantized waveform (n_samples). |
Fast WaveNet Generation Algorithm: https://arxiv.org/abs/1611.09482
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/wavenet.py
def generate(self, x, h, n_samples, interval=None, mode="sampling"):
"""Generate a waveform with fast genration algorithm.
This generation based on `Fast WaveNet Generation Algorithm`_.
Args:
x (LongTensor): Initial waveform tensor with the shape (T,).
h (Tensor): Auxiliary feature tensor with the shape (n_samples + T, n_aux).
n_samples (int): Number of samples to be generated.
interval (int, optional): Log interval.
mode (str, optional): "sampling" or "argmax".
Return:
ndarray: Generated quantized waveform (n_samples).
.. _`Fast WaveNet Generation Algorithm`: https://arxiv.org/abs/1611.09482
"""
# reshape inputs
assert len(x.shape) == 1
assert len(h.shape) == 2 and h.shape[1] == self.n_aux
x = x.unsqueeze(0)
h = h.transpose(0, 1).unsqueeze(0)
# perform upsampling
if self.upsampling_factor > 0:
h = self.upsampling(h)
# padding for shortage
if n_samples > h.shape[2]:
h = F.pad(h, (0, n_samples - h.shape[2]), "replicate")
# padding if the length less than
n_pad = self.receptive_field - x.size(1)
if n_pad > 0:
x = F.pad(x, (n_pad, 0), "constant", self.n_quantize // 2)
h = F.pad(h, (n_pad, 0), "replicate")
# prepare buffer
output = self._preprocess(x)
h_ = h[:, :, :x.size(1)]
output_buffer = []
buffer_size = []
for l, d in enumerate(self.dilations):
output, _ = self._residual_forward(
output, h_, self.dil_sigmoid[l], self.dil_tanh[l],
self.aux_1x1_sigmoid[l], self.aux_1x1_tanh[l],
self.skip_1x1[l], self.res_1x1[l])
if d == 2 ** (self.dilation_depth - 1):
buffer_size.append(self.kernel_size - 1)
else:
buffer_size.append(d * 2 * (self.kernel_size - 1))
output_buffer.append(output[:, :, -buffer_size[l] - 1: -1])
# generate
samples = x[0]
start_time = time.time()
for i in range(n_samples):
output = samples[-self.kernel_size * 2 + 1:].unsqueeze(0)
output = self._preprocess(output)
h_ = h[:, :, samples.size(0) - 1].contiguous().view(1, self.n_aux, 1)
output_buffer_next = []
skip_connections = []
for l, d in enumerate(self.dilations):
output, skip = self._generate_residual_forward(
output, h_, self.dil_sigmoid[l], self.dil_tanh[l],
self.aux_1x1_sigmoid[l], self.aux_1x1_tanh[l],
self.skip_1x1[l], self.res_1x1[l])
output = torch.cat([output_buffer[l], output], dim=2)
output_buffer_next.append(output[:, :, -buffer_size[l]:])
skip_connections.append(skip)
# update buffer
output_buffer = output_buffer_next
# get predicted sample
output = sum(skip_connections)
output = self._postprocess(output)[0]
if mode == "sampling":
posterior = F.softmax(output[-1], dim=0)
dist = torch.distributions.Categorical(posterior)
sample = dist.sample().unsqueeze(0)
elif mode == "argmax":
sample = output.argmax(-1)
else:
logging.error("mode should be sampling or argmax")
sys.exit(1)
samples = torch.cat([samples, sample], dim=0)
# show progress
if interval is not None and (i + 1) % interval == 0:
elapsed_time_per_sample = (time.time() - start_time) / interval
logging.info("%d/%d estimated time = %.3f sec (%.3f sec / sample)" % (
i + 1, n_samples, (n_samples - i - 1) * elapsed_time_per_sample, elapsed_time_per_sample))
start_time = time.time()
return samples[-n_samples:].cpu().numpy()
decode_mu_law(y, mu=256)
¶
Perform mu-law decoding.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| y | ndarray | Quantized audio signal with the range from 0 to mu - 1. | required |
| mu | int | Quantized level. | 256 |

Returns:

| Type | Description |
|---|---|
| ndarray | Audio signal with the range from -1 to 1. |
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/wavenet.py
def decode_mu_law(y, mu=256):
"""Perform mu-law decoding.
Args:
x (ndarray): Quantized audio signal with the range from 0 to mu - 1.
mu (int): Quantized level.
Returns:
ndarray: Audio signal with the range from -1 to 1.
"""
mu = mu - 1
fx = (y - 0.5) / mu * 2 - 1
x = np.sign(fx) / mu * ((1 + mu) ** np.abs(fx) - 1)
return x
encode_mu_law(x, mu=256)
¶
Perform mu-law encoding.
Parameters:
Name | Type | Description | Default
---|---|---|---
x | ndarray | Audio signal with the range from -1 to 1. | required
mu | int | Quantized level. | 256
Returns:
Type | Description
---|---
ndarray | Quantized audio signal with the range from 0 to mu - 1.
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/wavenet.py
def encode_mu_law(x, mu=256):
"""Perform mu-law encoding.
Args:
x (ndarray): Audio signal with the range from -1 to 1.
mu (int): Quantized level.
Returns:
ndarray: Quantized audio signal with the range from 0 to mu - 1.
"""
mu = mu - 1
fx = np.sign(x) * np.log(1 + mu * np.abs(x)) / np.log(1 + mu)
return np.floor((fx + 1) / 2 * mu + 0.5).astype(np.int64)
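As a quick illustration (not part of the library source; the import path below is assumed from the "Source code in" location above), the two helpers round-trip a waveform with only a small quantization error:
import numpy as np

# Assumed import path, derived from adviser/tools/espnet_minimal/nets/pytorch_backend/wavenet.py
from tools.espnet_minimal.nets.pytorch_backend.wavenet import encode_mu_law, decode_mu_law

signal = np.sin(np.linspace(0, 2 * np.pi, 16000)).astype(np.float32)
quantized = encode_mu_law(signal, mu=256)         # int64 values in [0, 255]
reconstructed = decode_mu_law(quantized, mu=256)  # back to roughly [-1, 1]
print(np.abs(signal - reconstructed).max())       # small quantization error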
initialize(m)
¶
Initialize conv layers with Xavier initialization.
Parameters:
Name | Type | Description | Default
---|---|---|---
m | torch.nn.Module | Torch module. | required
Source code in adviser/tools/espnet_minimal/nets/pytorch_backend/wavenet.py
scorer_interface
¶
Scorer interface module.
BatchScorerInterface (ScorerInterface)
¶
Batch scorer interface.
Source code in adviser/tools/espnet_minimal/nets/scorer_interface.py
class BatchScorerInterface(ScorerInterface):
"""Batch scorer interface."""
def score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]:
"""Score new token batch (required).
Args:
ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
states (List[Any]): Scorer states for prefix tokens.
xs (torch.Tensor): The encoder feature that generates ys (n_batch, xlen, n_feat).
Returns:
tuple[torch.Tensor, List[Any]]: Tuple of
batchfied scores for next token with shape of `(n_batch, n_vocab)`
and next state list for ys.
"""
raise NotImplementedError
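As an illustration only (not part of the library), a minimal subclass might look like the following sketch; it assigns every vocabulary token the same score, so it leaves the beam-search ranking unchanged. The n_vocab parameter is an assumption of this sketch:
from typing import Any, List, Tuple

import torch


class UniformBatchScorer(BatchScorerInterface):
    """Toy batch scorer: zero (log-domain) score for every token."""

    def __init__(self, n_vocab: int):
        self.n_vocab = n_vocab

    def score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]:
        # Return a (n_batch, n_vocab) tensor of zeros and pass the states through unchanged.
        scores = torch.zeros(ys.size(0), self.n_vocab, dtype=xs.dtype, device=xs.device)
        return scores, states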
score(self, ys, states, xs)
¶
Score new token batch (required).
Parameters:
Name | Type | Description | Default
---|---|---|---
ys | torch.Tensor | torch.int64 prefix tokens (n_batch, ylen). | required
states | List[Any] | Scorer states for prefix tokens. | required
xs | torch.Tensor | The encoder feature that generates ys (n_batch, xlen, n_feat). | required
Returns:
Type | Description
---|---
tuple[torch.Tensor, List[Any]] | Tuple of batchified scores for the next token with shape (n_batch, n_vocab) and the next state list for ys.
Source code in adviser/tools/espnet_minimal/nets/scorer_interface.py
def score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]:
"""Score new token batch (required).
Args:
ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
states (List[Any]): Scorer states for prefix tokens.
xs (torch.Tensor): The encoder feature that generates ys (n_batch, xlen, n_feat).
Returns:
tuple[torch.Tensor, List[Any]]: Tuple of
batchfied scores for next token with shape of `(n_batch, n_vocab)`
and next state list for ys.
"""
raise NotImplementedError
PartialScorerInterface (ScorerInterface)
¶
Partial scorer interface for beam search.
The partial scorer performs scoring after the non-partial scorers have finished, and receives pre-pruned next tokens to score, because scoring all tokens would be too expensive.
Examples:
- Prefix search for connectionist-temporal-classification models
- services.hci.speech.espnet_minimal.nets.scorers.ctc.CTCPrefixScorer
Source code in adviser/tools/espnet_minimal/nets/scorer_interface.py
class PartialScorerInterface(ScorerInterface):
"""Partial scorer interface for beam search.
The partial scorer performs scoring when non-partial scorer finished scoring,
and recieves pre-pruned next tokens to score because it is too heavy to score
all the tokens.
Examples:
* Prefix search for connectionist-temporal-classification models
* :class:`services.hci.speech.espnet_minimal.nets.scorers.ctc.CTCPrefixScorer`
"""
def score_partial(self, y: torch.Tensor, next_tokens: torch.Tensor, state: Any, x: torch.Tensor) \
-> Tuple[torch.Tensor, Any]:
"""Score new token (required).
Args:
y (torch.Tensor): 1D prefix token
next_tokens (torch.Tensor): torch.int64 next token to score
state: decoder state for prefix tokens
x (torch.Tensor): The encoder feature that generates ys
Returns:
tuple[torch.Tensor, Any]: Tuple of a score tensor for y that has a shape `(len(next_tokens),)`
and next state for ys
"""
raise NotImplementedError
score_partial(self, y, next_tokens, state, x)
¶
Score new token (required).
Parameters:
Name | Type | Description | Default
---|---|---|---
y | torch.Tensor | 1D prefix token | required
next_tokens | torch.Tensor | torch.int64 next tokens to score | required
state | Any | Decoder state for prefix tokens | required
x | torch.Tensor | The encoder feature that generates ys | required
Returns:
Type | Description
---|---
tuple[torch.Tensor, Any] | Tuple of a score tensor for y with shape (len(next_tokens),) and the next state for ys
Source code in adviser/tools/espnet_minimal/nets/scorer_interface.py
def score_partial(self, y: torch.Tensor, next_tokens: torch.Tensor, state: Any, x: torch.Tensor) \
-> Tuple[torch.Tensor, Any]:
"""Score new token (required).
Args:
y (torch.Tensor): 1D prefix token
next_tokens (torch.Tensor): torch.int64 next token to score
state: decoder state for prefix tokens
x (torch.Tensor): The encoder feature that generates ys
Returns:
tuple[torch.Tensor, Any]: Tuple of a score tensor for y that has a shape `(len(next_tokens),)`
and next state for ys
"""
raise NotImplementedError
ScorerInterface
¶
Scorer interface for beam search.
The scorer performs scoring over all tokens in the vocabulary.
Examples:
- Search heuristics
  - services.hci.speech.espnet_minimal.nets.scorers.length_bonus.LengthBonus
- Decoder networks of the sequence-to-sequence models
  - services.hci.speech.espnet_minimal.nets.pytorch_backend.nets.transformer.decoder.Decoder
  - services.hci.speech.espnet_minimal.nets.pytorch_backend.nets.rnn.decoders.Decoder
- Neural language models
  - services.hci.speech.espnet_minimal.nets.pytorch_backend.lm.transformer.TransformerLM
  - services.hci.speech.espnet_minimal.nets.pytorch_backend.lm.default.DefaultRNNLM
  - services.hci.speech.espnet_minimal.nets.pytorch_backend.lm.seq_rnn.SequentialRNNLM
Source code in adviser/tools/espnet_minimal/nets/scorer_interface.py
class ScorerInterface:
"""Scorer interface for beam search.
The scorer performs scoring of the all tokens in vocabulary.
Examples:
* Search heuristics
* :class:`services.hci.speech.espnet_minimal.nets.scorers.length_bonus.LengthBonus`
* Decoder networks of the sequence-to-sequence models
* :class:`services.hci.speech.espnet_minimal.nets.pytorch_backend.nets.transformer.decoder.Decoder`
* :class:`services.hci.speech.espnet_minimal.nets.pytorch_backend.nets.rnn.decoders.Decoder`
* Neural language models
* :class:`services.hci.speech.espnet_minimal.nets.pytorch_backend.lm.transformer.TransformerLM`
* :class:`services.hci.speech.espnet_minimal.nets.pytorch_backend.lm.default.DefaultRNNLM`
* :class:`services.hci.speech.espnet_minimal.nets.pytorch_backend.lm.seq_rnn.SequentialRNNLM`
"""
def init_state(self, x: torch.Tensor) -> Any:
"""Get an initial state for decoding (optional).
Args:
x (torch.Tensor): The encoded feature tensor
Returns: initial state
"""
return None
def select_state(self, state: Any, i: int) -> Any:
"""Select state with relative ids in the main beam search.
Args:
state: Decoder state for prefix tokens
i (int): Index to select a state in the main beam search
Returns:
state: pruned state
"""
return None if state is None else state[i]
def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
"""Score new token (required).
Args:
y (torch.Tensor): 1D torch.int64 prefix tokens.
state: Scorer state for prefix tokens
x (torch.Tensor): The encoder feature that generates ys.
Returns:
tuple[torch.Tensor, Any]: Tuple of
scores for next token that has a shape of `(n_vocab)`
and next state for ys
"""
raise NotImplementedError
def final_score(self, state: Any) -> float:
"""Score eos (optional).
Args:
state: Scorer state for prefix tokens
Returns:
float: final score
"""
return 0.0
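For illustration (a sketch, not the library's LengthBonus implementation), a stateless scorer that adds a constant bonus to every candidate token could be written as follows; n_vocab and bonus are parameters of this sketch:
from typing import Any, Tuple

import torch


class ConstantBonusScorer(ScorerInterface):
    """Toy scorer: the same bonus for every token, which rewards longer hypotheses."""

    def __init__(self, n_vocab: int, bonus: float = 1.0):
        self.n_vocab = n_vocab
        self.bonus = bonus

    def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
        # Constant scores of shape (n_vocab,); state is not used by this scorer.
        return torch.full((self.n_vocab,), self.bonus, dtype=x.dtype, device=x.device), None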
final_score(self, state)
¶
Score eos (optional).
Parameters:
Name | Type | Description | Default
---|---|---|---
state | Any | Scorer state for prefix tokens | required
Returns:
Type | Description
---|---
float | final score
init_state(self, x)
¶
Get an initial state for decoding (optional).
Parameters:
Name | Type | Description | Default
---|---|---|---
x | torch.Tensor | The encoded feature tensor | required
Returns: initial state
score(self, y, state, x)
¶
Score new token (required).
Parameters:
Name | Type | Description | Default
---|---|---|---
y | torch.Tensor | 1D torch.int64 prefix tokens. | required
state | Any | Scorer state for prefix tokens | required
x | torch.Tensor | The encoder feature that generates ys. | required
Returns:
Type | Description
---|---
tuple[torch.Tensor, Any] | Tuple of scores for the next token with shape (n_vocab) and the next state for ys
Source code in adviser/tools/espnet_minimal/nets/scorer_interface.py
def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
"""Score new token (required).
Args:
y (torch.Tensor): 1D torch.int64 prefix tokens.
state: Scorer state for prefix tokens
x (torch.Tensor): The encoder feature that generates ys.
Returns:
tuple[torch.Tensor, Any]: Tuple of
scores for next token that has a shape of `(n_vocab)`
and next state for ys
"""
raise NotImplementedError
select_state(self, state, i)
¶
Select state with relative ids in the main beam search.
Parameters:
Name | Type | Description | Default
---|---|---|---
state | Any | Decoder state for prefix tokens | required
i | int | Index to select a state in the main beam search | required
Returns:
Type | Description
---|---
state | pruned state
Source code in adviser/tools/espnet_minimal/nets/scorer_interface.py
scorers
special
¶
ctc
¶
ScorerInterface implementation for CTC.
CTCPrefixScorer (PartialScorerInterface)
¶
Decoder interface wrapper for CTCPrefixScore.
Source code in adviser/tools/espnet_minimal/nets/scorers/ctc.py
class CTCPrefixScorer(PartialScorerInterface):
"""Decoder interface wrapper for CTCPrefixScore."""
def __init__(self, ctc: torch.nn.Module, eos: int):
"""Initialize class.
Args:
ctc (torch.nn.Module): The CTC implementaiton. For example, :class:`services.hci.speech.espnet_minimal.nets.pytorch_backend.ctc.CTC`
eos (int): The end-of-sequence id.
"""
self.ctc = ctc
self.eos = eos
self.impl = None
def init_state(self, x: torch.Tensor):
"""Get an initial state for decoding.
Args:
x (torch.Tensor): The encoded feature tensor
Returns: initial state
"""
logp = self.ctc.log_softmax(x.unsqueeze(0)).detach().squeeze(0).cpu().numpy()
# TODO(karita): use CTCPrefixScoreTH
self.impl = CTCPrefixScore(logp, 0, self.eos, np)
return 0, self.impl.initial_state()
def select_state(self, state, i):
"""Select state with relative ids in the main beam search.
Args:
state: Decoder state for prefix tokens
i (int): Index to select a state in the main beam search
Returns:
state: pruned state
"""
sc, st = state
return sc[i], st[i]
def score_partial(self, y, ids, state, x):
"""Score new token.
Args:
y (torch.Tensor): 1D prefix token
next_tokens (torch.Tensor): torch.int64 next token to score
state: decoder state for prefix tokens
x (torch.Tensor): 2D encoder feature that generates ys
Returns:
tuple[torch.Tensor, Any]: Tuple of a score tensor for y that has a shape `(len(next_tokens),)`
and next state for ys
"""
prev_score, state = state
presub_score, new_st = self.impl(y.cpu(), ids.cpu(), state)
tscore = torch.as_tensor(presub_score - prev_score, device=x.device, dtype=x.dtype)
return tscore, (presub_score, new_st)
__init__(self, ctc, eos)
special
¶
Initialize class.
Parameters:
Name | Type | Description | Default
---|---|---|---
ctc | torch.nn.Module | The CTC implementation, e.g. services.hci.speech.espnet_minimal.nets.pytorch_backend.ctc.CTC | required
eos | int | The end-of-sequence id. | required
Source code in adviser/tools/espnet_minimal/nets/scorers/ctc.py
init_state(self, x)
¶
Get an initial state for decoding.
Parameters:
Name | Type | Description | Default
---|---|---|---
x | torch.Tensor | The encoded feature tensor | required
Returns: initial state
Source code in adviser/tools/espnet_minimal/nets/scorers/ctc.py
def init_state(self, x: torch.Tensor):
"""Get an initial state for decoding.
Args:
x (torch.Tensor): The encoded feature tensor
Returns: initial state
"""
logp = self.ctc.log_softmax(x.unsqueeze(0)).detach().squeeze(0).cpu().numpy()
# TODO(karita): use CTCPrefixScoreTH
self.impl = CTCPrefixScore(logp, 0, self.eos, np)
return 0, self.impl.initial_state()
score_partial(self, y, ids, state, x)
¶
Score new token.
Parameters:
Name | Type | Description | Default
---|---|---|---
y | torch.Tensor | 1D prefix token | required
ids | torch.Tensor | torch.int64 next tokens to score | required
state | Any | Decoder state for prefix tokens | required
x | torch.Tensor | 2D encoder feature that generates ys | required
Returns:
Type | Description
---|---
tuple[torch.Tensor, Any] | Tuple of a score tensor for y with shape (len(next_tokens),) and the next state for ys
Source code in adviser/tools/espnet_minimal/nets/scorers/ctc.py
def score_partial(self, y, ids, state, x):
"""Score new token.
Args:
y (torch.Tensor): 1D prefix token
next_tokens (torch.Tensor): torch.int64 next token to score
state: decoder state for prefix tokens
x (torch.Tensor): 2D encoder feature that generates ys
Returns:
tuple[torch.Tensor, Any]: Tuple of a score tensor for y that has a shape `(len(next_tokens),)`
and next state for ys
"""
prev_score, state = state
presub_score, new_st = self.impl(y.cpu(), ids.cpu(), state)
tscore = torch.as_tensor(presub_score - prev_score, device=x.device, dtype=x.dtype)
return tscore, (presub_score, new_st)
select_state(self, state, i)
¶
Select state with relative ids in the main beam search.
Parameters:
Name | Type | Description | Default
---|---|---|---
state |  | Decoder state for prefix tokens | required
i | int | Index to select a state in the main beam search | required
Returns:
Type | Description
---|---
state | pruned state
Source code in adviser/tools/espnet_minimal/nets/scorers/ctc.py
tts_interface
¶
TTS interface related modules.
TTSInterface
¶
TTS Interface for ESPnet model implementation.
Source code in adviser/tools/espnet_minimal/nets/tts_interface.py
class TTSInterface(object):
"""TTS Interface for ESPnet model implementation."""
@staticmethod
def add_arguments(parser):
"""Add model specific argments to parser."""
return parser
def __init__(self):
"""Initilize TTS module."""
self.reporter = None
def forward(self, *args, **kwargs):
"""Calculate TTS forward propagation.
Returns:
Tensor: Loss value.
"""
raise NotImplementedError("forward method is not implemented")
def inference(self, *args, **kwargs):
"""Generate the sequence of features given the sequences of characters.
Returns:
Tensor: The sequence of generated features (L, odim).
Tensor: The sequence of stop probabilities (L,).
Tensor: The sequence of attention weights (L, T).
"""
raise NotImplementedError("inference method is not implemented")
def calculate_all_attentions(self, *args, **kwargs):
"""Calculate TTS attention weights.
Args:
Tensor: Batch of attention weights (B, Lmax, Tmax).
"""
raise NotImplementedError("calculate_all_attentions method is not implemented")
def load_pretrained_model(self, model_path):
"""Load pretrained model parameters."""
torch_load(model_path, self)
@property
def base_plot_keys(self):
"""Return base key names to plot during training.
The keys should match what `chainer.reporter` reports.
if you add the key `loss`, the reporter will report `main/loss` and `validation/main/loss` values.
also `loss.png` will be created as a figure visulizing `main/loss` and `validation/main/loss` values.
Returns:
list[str]: Base keys to plot during training.
"""
return ['loss']
base_plot_keys
property
readonly
¶
Return base key names to plot during training.
The keys should match what chainer.reporter reports. If you add the key loss, the reporter will report main/loss and validation/main/loss values, and loss.png will be created as a figure visualizing main/loss and validation/main/loss.
Returns:
Type | Description
---|---
list[str] | Base keys to plot during training.
__init__(self)
special
¶
add_arguments(parser)
staticmethod
¶
calculate_all_attentions(self, *args, **kwargs)
¶
Calculate TTS attention weights.
Returns:
Type | Description
---|---
Tensor | Batch of attention weights (B, Lmax, Tmax).
forward(self, *args, **kwargs)
¶
Calculate TTS forward propagation.
Returns:
Type | Description
---|---
Tensor | Loss value.
inference(self, *args, **kwargs)
¶
Generate the sequence of features given the sequences of characters.
Returns:
Type | Description
---|---
Tensor | The sequence of generated features (L, odim).
Tensor | The sequence of stop probabilities (L,).
Tensor | The sequence of attention weights (L, T).
Source code in adviser/tools/espnet_minimal/nets/tts_interface.py
def inference(self, *args, **kwargs):
"""Generate the sequence of features given the sequences of characters.
Returns:
Tensor: The sequence of generated features (L, odim).
Tensor: The sequence of stop probabilities (L,).
Tensor: The sequence of attention weights (L, T).
"""
raise NotImplementedError("inference method is not implemented")
load_pretrained_model(self, model_path)
¶
utils
special
¶
check_kwargs
¶
check_kwargs(func, kwargs, name=None)
¶
Check that kwargs are valid for func.
If kwargs are invalid, raise a TypeError just as Python would by default.
:param function func: function to be validated
:param dict kwargs: keyword arguments for func
:param str name: name used in the TypeError message (defaults to the function name)
Source code in adviser/tools/espnet_minimal/utils/check_kwargs.py
def check_kwargs(func, kwargs, name=None):
"""check kwargs are valid for func
If kwargs are invalid, raise TypeError as same as python default
:param function func: function to be validated
:param dict kwargs: keyword arguments for func
:param str name: name used in TypeError (default is func name)
"""
try:
params = inspect.signature(func).parameters
except ValueError:
return
if name is None:
name = func.__name__
for k in kwargs.keys():
if k not in params:
raise TypeError(f"{name}() got an unexpected keyword argument '{k}'")
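A minimal usage sketch (the train function below is a made-up example):
def train(model, lr=0.001):
    pass

check_kwargs(train, {'lr': 0.01})             # passes silently
check_kwargs(train, {'learning_rate': 0.01})  # TypeError: train() got an unexpected keyword argument 'learning_rate'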
cli_readers
¶
HDF5Reader
¶
Source code in adviser/tools/espnet_minimal/utils/cli_readers.py
class HDF5Reader:
def __init__(self, rspecifier, return_shape=False):
if ':' not in rspecifier:
raise ValueError('Give "rspecifier" such as "ark:some.ark: {}"'
.format(rspecifier))
self.rspecifier = rspecifier
self.ark_or_scp, self.filepath = self.rspecifier.split(':', 1)
if self.ark_or_scp not in ['ark', 'scp']:
raise ValueError(f'Must be scp or ark: {self.ark_or_scp}')
self.return_shape = return_shape
def __iter__(self):
if self.ark_or_scp == 'scp':
hdf5_dict = {}
with open(self.filepath, 'r', encoding='utf-8') as f:
for line in f:
key, value = line.rstrip().split(None, 1)
if ':' not in value:
raise RuntimeError(
'scp file for hdf5 should be like: '
'"uttid filepath.h5:key": {}({})'
.format(line, self.filepath))
path, h5_key = value.split(':', 1)
hdf5_file = hdf5_dict.get(path)
if hdf5_file is None:
try:
hdf5_file = h5py.File(path, 'r')
except Exception:
logging.error(
'Error when loading {}'.format(path))
raise
hdf5_dict[path] = hdf5_file
try:
data = hdf5_file[h5_key]
except Exception:
logging.error('Error when loading {} with key={}'
.format(path, h5_key))
raise
if self.return_shape:
yield key, data.shape
else:
yield key, data[()]
# Closing all files
for k in hdf5_dict:
try:
hdf5_dict[k].close()
except Exception:
pass
else:
if self.filepath == '-':
# Required h5py>=2.9
filepath = io.BytesIO(sys.stdin.buffer.read())
else:
filepath = self.filepath
with h5py.File(filepath, 'r') as f:
for key in f:
if self.return_shape:
yield key, f[key].shape
else:
yield key, f[key][()]
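A usage sketch, assuming an scp file whose lines look like "uttid some.h5:key" and whose referenced HDF5 files exist:
# Iterate over (uttid, ndarray) pairs.
for uttid, array in HDF5Reader('scp:feats.scp'):
    print(uttid, array.shape)

# With return_shape=True only the shapes are read, which avoids loading the data.
for uttid, shape in HDF5Reader('scp:feats.scp', return_shape=True):
    print(uttid, shape)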
__init__(self, rspecifier, return_shape=False)
special
¶
Source code in adviser/tools/espnet_minimal/utils/cli_readers.py
def __init__(self, rspecifier, return_shape=False):
if ':' not in rspecifier:
raise ValueError('Give "rspecifier" such as "ark:some.ark: {}"'
.format(rspecifier))
self.rspecifier = rspecifier
self.ark_or_scp, self.filepath = self.rspecifier.split(':', 1)
if self.ark_or_scp not in ['ark', 'scp']:
raise ValueError(f'Must be scp or ark: {self.ark_or_scp}')
self.return_shape = return_shape
__iter__(self)
special
¶
Source code in adviser/tools/espnet_minimal/utils/cli_readers.py
def __iter__(self):
if self.ark_or_scp == 'scp':
hdf5_dict = {}
with open(self.filepath, 'r', encoding='utf-8') as f:
for line in f:
key, value = line.rstrip().split(None, 1)
if ':' not in value:
raise RuntimeError(
'scp file for hdf5 should be like: '
'"uttid filepath.h5:key": {}({})'
.format(line, self.filepath))
path, h5_key = value.split(':', 1)
hdf5_file = hdf5_dict.get(path)
if hdf5_file is None:
try:
hdf5_file = h5py.File(path, 'r')
except Exception:
logging.error(
'Error when loading {}'.format(path))
raise
hdf5_dict[path] = hdf5_file
try:
data = hdf5_file[h5_key]
except Exception:
logging.error('Error when loading {} with key={}'
.format(path, h5_key))
raise
if self.return_shape:
yield key, data.shape
else:
yield key, data[()]
# Closing all files
for k in hdf5_dict:
try:
hdf5_dict[k].close()
except Exception:
pass
else:
if self.filepath == '-':
# Required h5py>=2.9
filepath = io.BytesIO(sys.stdin.buffer.read())
else:
filepath = self.filepath
with h5py.File(filepath, 'r') as f:
for key in f:
if self.return_shape:
yield key, f[key].shape
else:
yield key, f[key][()]
KaldiReader
¶
Source code in adviser/tools/espnet_minimal/utils/cli_readers.py
class KaldiReader:
def __init__(self, rspecifier, return_shape=False, segments=None):
self.rspecifier = rspecifier
self.return_shape = return_shape
self.segments = segments
def __iter__(self):
with kaldiio.ReadHelper(
self.rspecifier, segments=self.segments) as reader:
for key, array in reader:
if self.return_shape:
array = array.shape
yield key, array
SoundHDF5Reader
¶
Source code in adviser/tools/espnet_minimal/utils/cli_readers.py
class SoundHDF5Reader:
def __init__(self, rspecifier, return_shape=False):
if ':' not in rspecifier:
raise ValueError('Give "rspecifier" such as "ark:some.ark: {}"'
.format(rspecifier))
self.ark_or_scp, self.filepath = rspecifier.split(':', 1)
if self.ark_or_scp not in ['ark', 'scp']:
raise ValueError(f'Must be scp or ark: {self.ark_or_scp}')
self.return_shape = return_shape
def __iter__(self):
if self.ark_or_scp == 'scp':
hdf5_dict = {}
with open(self.filepath, 'r', encoding='utf-8') as f:
for line in f:
key, value = line.rstrip().split(None, 1)
if ':' not in value:
raise RuntimeError(
'scp file for hdf5 should be like: '
'"uttid filepath.h5:key": {}({})'
.format(line, self.filepath))
path, h5_key = value.split(':', 1)
hdf5_file = hdf5_dict.get(path)
if hdf5_file is None:
try:
hdf5_file = SoundHDF5File(path, 'r')
except Exception:
logging.error(
'Error when loading {}'.format(path))
raise
hdf5_dict[path] = hdf5_file
try:
data = hdf5_file[h5_key]
except Exception:
logging.error('Error when loading {} with key={}'
.format(path, h5_key))
raise
# Change Tuple[ndarray, int] -> Tuple[int, ndarray]
# (soundfile style -> scipy style)
array, rate = data
if self.return_shape:
array = array.shape
yield key, (rate, array)
# Closing all files
for k in hdf5_dict:
try:
hdf5_dict[k].close()
except Exception:
pass
else:
if self.filepath == '-':
# Required h5py>=2.9
filepath = io.BytesIO(sys.stdin.buffer.read())
else:
filepath = self.filepath
for key, (a, r) in SoundHDF5File(filepath, 'r').items():
if self.return_shape:
a = a.shape
yield key, (r, a)
__init__(self, rspecifier, return_shape=False)
special
¶
Source code in adviser/tools/espnet_minimal/utils/cli_readers.py
def __init__(self, rspecifier, return_shape=False):
if ':' not in rspecifier:
raise ValueError('Give "rspecifier" such as "ark:some.ark: {}"'
.format(rspecifier))
self.ark_or_scp, self.filepath = rspecifier.split(':', 1)
if self.ark_or_scp not in ['ark', 'scp']:
raise ValueError(f'Must be scp or ark: {self.ark_or_scp}')
self.return_shape = return_shape
__iter__(self)
special
¶
Source code in adviser/tools/espnet_minimal/utils/cli_readers.py
def __iter__(self):
if self.ark_or_scp == 'scp':
hdf5_dict = {}
with open(self.filepath, 'r', encoding='utf-8') as f:
for line in f:
key, value = line.rstrip().split(None, 1)
if ':' not in value:
raise RuntimeError(
'scp file for hdf5 should be like: '
'"uttid filepath.h5:key": {}({})'
.format(line, self.filepath))
path, h5_key = value.split(':', 1)
hdf5_file = hdf5_dict.get(path)
if hdf5_file is None:
try:
hdf5_file = SoundHDF5File(path, 'r')
except Exception:
logging.error(
'Error when loading {}'.format(path))
raise
hdf5_dict[path] = hdf5_file
try:
data = hdf5_file[h5_key]
except Exception:
logging.error('Error when loading {} with key={}'
.format(path, h5_key))
raise
# Change Tuple[ndarray, int] -> Tuple[int, ndarray]
# (soundfile style -> scipy style)
array, rate = data
if self.return_shape:
array = array.shape
yield key, (rate, array)
# Closing all files
for k in hdf5_dict:
try:
hdf5_dict[k].close()
except Exception:
pass
else:
if self.filepath == '-':
# Required h5py>=2.9
filepath = io.BytesIO(sys.stdin.buffer.read())
else:
filepath = self.filepath
for key, (a, r) in SoundHDF5File(filepath, 'r').items():
if self.return_shape:
a = a.shape
yield key, (r, a)
SoundReader
¶
Source code in adviser/tools/espnet_minimal/utils/cli_readers.py
class SoundReader:
def __init__(self, rspecifier, return_shape=False):
if ':' not in rspecifier:
raise ValueError('Give "rspecifier" such as "scp:some.scp: {}"'
.format(rspecifier))
self.ark_or_scp, self.filepath = rspecifier.split(':', 1)
if self.ark_or_scp != 'scp':
raise ValueError('Only supporting "scp" for sound file: {}'
.format(self.ark_or_scp))
self.return_shape = return_shape
def __iter__(self):
with open(self.filepath, 'r', encoding='utf-8') as f:
for line in f:
key, sound_file_path = line.rstrip().split(None, 1)
# Assume PCM16
array, rate = soundfile.read(sound_file_path, dtype='int16')
# Change Tuple[ndarray, int] -> Tuple[int, ndarray]
# (soundfile style -> scipy style)
if self.return_shape:
array = array.shape
yield key, (rate, array)
__init__(self, rspecifier, return_shape=False)
special
¶
Source code in adviser/tools/espnet_minimal/utils/cli_readers.py
def __init__(self, rspecifier, return_shape=False):
if ':' not in rspecifier:
raise ValueError('Give "rspecifier" such as "scp:some.scp: {}"'
.format(rspecifier))
self.ark_or_scp, self.filepath = rspecifier.split(':', 1)
if self.ark_or_scp != 'scp':
raise ValueError('Only supporting "scp" for sound file: {}'
.format(self.ark_or_scp))
self.return_shape = return_shape
__iter__(self)
special
¶
Source code in adviser/tools/espnet_minimal/utils/cli_readers.py
def __iter__(self):
with open(self.filepath, 'r', encoding='utf-8') as f:
for line in f:
key, sound_file_path = line.rstrip().split(None, 1)
# Assume PCM16
array, rate = soundfile.read(sound_file_path, dtype='int16')
# Change Tuple[ndarray, int] -> Tuple[int, ndarray]
# (soundfile style -> scipy style)
if self.return_shape:
array = array.shape
yield key, (rate, array)
file_reader_helper(rspecifier, filetype='mat', return_shape=False, segments=None)
¶
Read uttid and array in kaldi style
This function might be a bit confusing as "ark" is used for HDF5 to imitate "kaldi-rspecifier".
Parameters:
Name | Type | Description | Default
---|---|---|---
rspecifier | str | Give as "ark:feats.ark" or "scp:feats.scp" | required
filetype | str | "mat" is kaldi-matrix, "hdf5": HDF5 | 'mat'
return_shape | bool | Return the shape of the matrix instead of the matrix. This can reduce IO cost for HDF5. | False
Returns:
Type | Description
---|---
Generator[Tuple[str, np.ndarray], None, None] | Generator yielding (uttid, array) pairs.
Examples:
Read from a kaldi-matrix ark file:
>>> for u, array in file_reader_helper('ark:feats.ark', 'mat'):
...     array
Read from an HDF5 file:
>>> for u, array in file_reader_helper('ark:feats.h5', 'hdf5'):
...     array
Source code in adviser/tools/espnet_minimal/utils/cli_readers.py
def file_reader_helper(rspecifier: str, filetype: str = 'mat',
return_shape: bool = False,
segments: str = None):
"""Read uttid and array in kaldi style
This function might be a bit confusing as "ark" is used
for HDF5 to imitate "kaldi-rspecifier".
Args:
rspecifier: Give as "ark:feats.ark" or "scp:feats.scp"
filetype: "mat" is kaldi-martix, "hdf5": HDF5
return_shape: Return the shape of the matrix,
instead of the matrix. This can reduce IO cost for HDF5.
Returns:
Generator[Tuple[str, np.ndarray], None, None]:
Examples:
Read from kaldi-matrix ark file:
>>> for u, array in file_reader_helper('ark:feats.ark', 'mat'):
... array
Read from HDF5 file:
>>> for u, array in file_reader_helper('ark:feats.h5', 'hdf5'):
... array
"""
if filetype == 'mat':
return KaldiReader(rspecifier, return_shape=return_shape,
segments=segments)
elif filetype == 'hdf5':
return HDF5Reader(rspecifier, return_shape=return_shape)
elif filetype == 'sound.hdf5':
return SoundHDF5Reader(rspecifier, return_shape=return_shape)
elif filetype == 'sound':
return SoundReader(rspecifier, return_shape=return_shape)
else:
raise NotImplementedError(f'filetype={filetype}')
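A usage sketch (file names are placeholders); the same loop works for every supported filetype, only the rspecifier and the filetype string change:
# Kaldi matrices from an scp index.
for uttid, mat in file_reader_helper('scp:feats.scp', filetype='mat'):
    print(uttid, mat.shape)

# Raw audio: the "sound" filetype yields (rate, ndarray) pairs.
for uttid, (rate, wav) in file_reader_helper('scp:wav.scp', filetype='sound'):
    print(uttid, rate, wav.shape)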
cli_utils
¶
assert_scipy_wav_style(value)
¶
Source code in adviser/tools/espnet_minimal/utils/cli_utils.py
get_commandline_args()
¶
Source code in adviser/tools/espnet_minimal/utils/cli_utils.py
def get_commandline_args():
extra_chars = [' ', ';', '&', '(', ')', '|', '^', '<', '>', '?', '*',
'[', ']', '$', '`', '"', '\\', '!', '{', '}']
# Escape the extra characters for shell
argv = [arg.replace('\'', '\'\\\'\'')
if all(char not in arg for char in extra_chars)
else '\'' + arg.replace('\'', '\'\\\'\'') + '\''
for arg in sys.argv]
return sys.executable + ' ' + ' '.join(argv)
is_scipy_wav_style(value)
¶
strtobool(x)
¶
cli_writers
¶
BaseWriter
¶
Source code in adviser/tools/espnet_minimal/utils/cli_writers.py
class BaseWriter:
def __setitem__(self, key, value):
raise NotImplementedError
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
def close(self):
try:
self.writer.close()
except Exception:
pass
if self.writer_scp is not None:
try:
self.writer_scp.close()
except Exception:
pass
if self.writer_nframe is not None:
try:
self.writer_nframe.close()
except Exception:
pass
HDF5Writer (BaseWriter)
¶
HDF5Writer
Examples:
>>> with HDF5Writer('ark:out.h5', compress=True) as f:
...     f['key'] = array
Source code in adviser/tools/espnet_minimal/utils/cli_writers.py
class HDF5Writer(BaseWriter):
"""HDF5Writer
Examples:
>>> with HDF5Writer('ark:out.h5', compress=True) as f:
... f['key'] = array
"""
def __init__(self, wspecifier, write_num_frames=None, compress=False):
spec_dict = parse_wspecifier(wspecifier)
self.filename = spec_dict['ark']
if compress:
self.kwargs = {'compression': 'gzip'}
else:
self.kwargs = {}
self.writer = h5py.File(spec_dict['ark'], 'w')
if 'scp' in spec_dict:
self.writer_scp = open(spec_dict['scp'], 'w', encoding='utf-8')
else:
self.writer_scp = None
if write_num_frames is not None:
self.writer_nframe = get_num_frames_writer(write_num_frames)
else:
self.writer_nframe = None
def __setitem__(self, key, value):
self.writer.create_dataset(key, data=value, **self.kwargs)
if self.writer_scp is not None:
self.writer_scp.write(f'{key} {self.filename}:{key}\n')
if self.writer_nframe is not None:
self.writer_nframe.write(f'{key} {len(value)}\n')
__init__(self, wspecifier, write_num_frames=None, compress=False)
special
¶
Source code in adviser/tools/espnet_minimal/utils/cli_writers.py
def __init__(self, wspecifier, write_num_frames=None, compress=False):
spec_dict = parse_wspecifier(wspecifier)
self.filename = spec_dict['ark']
if compress:
self.kwargs = {'compression': 'gzip'}
else:
self.kwargs = {}
self.writer = h5py.File(spec_dict['ark'], 'w')
if 'scp' in spec_dict:
self.writer_scp = open(spec_dict['scp'], 'w', encoding='utf-8')
else:
self.writer_scp = None
if write_num_frames is not None:
self.writer_nframe = get_num_frames_writer(write_num_frames)
else:
self.writer_nframe = None
__setitem__(self, key, value)
special
¶
Source code in adviser/tools/espnet_minimal/utils/cli_writers.py
KaldiWriter (BaseWriter)
¶
Source code in adviser/tools/espnet_minimal/utils/cli_writers.py
class KaldiWriter(BaseWriter):
def __init__(self, wspecifier, write_num_frames=None, compress=False,
compression_method=2):
if compress:
self.writer = kaldiio.WriteHelper(
wspecifier, compression_method=compression_method)
else:
self.writer = kaldiio.WriteHelper(wspecifier)
self.writer_scp = None
if write_num_frames is not None:
self.writer_nframe = get_num_frames_writer(write_num_frames)
else:
self.writer_nframe = None
def __setitem__(self, key, value):
self.writer[key] = value
if self.writer_nframe is not None:
self.writer_nframe.write(f'{key} {len(value)}\n')
__init__(self, wspecifier, write_num_frames=None, compress=False, compression_method=2)
special
¶
Source code in adviser/tools/espnet_minimal/utils/cli_writers.py
def __init__(self, wspecifier, write_num_frames=None, compress=False,
compression_method=2):
if compress:
self.writer = kaldiio.WriteHelper(
wspecifier, compression_method=compression_method)
else:
self.writer = kaldiio.WriteHelper(wspecifier)
self.writer_scp = None
if write_num_frames is not None:
self.writer_nframe = get_num_frames_writer(write_num_frames)
else:
self.writer_nframe = None
__setitem__(self, key, value)
special
¶
SoundHDF5Writer (BaseWriter)
¶
SoundHDF5Writer
Examples:
>>> fs = 16000
>>> with SoundHDF5Writer('ark:out.h5') as f:
...     f['key'] = fs, array
Source code in adviser/tools/espnet_minimal/utils/cli_writers.py
class SoundHDF5Writer(BaseWriter):
"""SoundHDF5Writer
Examples:
>>> fs = 16000
>>> with SoundHDF5Writer('ark:out.h5') as f:
... f['key'] = fs, array
"""
def __init__(self, wspecifier, write_num_frames=None, pcm_format='wav'):
self.pcm_format = pcm_format
spec_dict = parse_wspecifier(wspecifier)
self.filename = spec_dict['ark']
self.writer = SoundHDF5File(spec_dict['ark'], 'w',
format=self.pcm_format)
if 'scp' in spec_dict:
self.writer_scp = open(spec_dict['scp'], 'w', encoding='utf-8')
else:
self.writer_scp = None
if write_num_frames is not None:
self.writer_nframe = get_num_frames_writer(write_num_frames)
else:
self.writer_nframe = None
def __setitem__(self, key, value):
assert_scipy_wav_style(value)
# Change Tuple[int, ndarray] -> Tuple[ndarray, int]
# (scipy style -> soundfile style)
value = (value[1], value[0])
self.writer.create_dataset(key, data=value)
if self.writer_scp is not None:
self.writer_scp.write(f'{key} {self.filename}:{key}\n')
if self.writer_nframe is not None:
self.writer_nframe.write(f'{key} {len(value[0])}\n')
__init__(self, wspecifier, write_num_frames=None, pcm_format='wav')
special
¶
Source code in adviser/tools/espnet_minimal/utils/cli_writers.py
def __init__(self, wspecifier, write_num_frames=None, pcm_format='wav'):
self.pcm_format = pcm_format
spec_dict = parse_wspecifier(wspecifier)
self.filename = spec_dict['ark']
self.writer = SoundHDF5File(spec_dict['ark'], 'w',
format=self.pcm_format)
if 'scp' in spec_dict:
self.writer_scp = open(spec_dict['scp'], 'w', encoding='utf-8')
else:
self.writer_scp = None
if write_num_frames is not None:
self.writer_nframe = get_num_frames_writer(write_num_frames)
else:
self.writer_nframe = None
__setitem__(self, key, value)
special
¶
Source code in adviser/tools/espnet_minimal/utils/cli_writers.py
def __setitem__(self, key, value):
assert_scipy_wav_style(value)
# Change Tuple[int, ndarray] -> Tuple[ndarray, int]
# (scipy style -> soundfile style)
value = (value[1], value[0])
self.writer.create_dataset(key, data=value)
if self.writer_scp is not None:
self.writer_scp.write(f'{key} {self.filename}:{key}\n')
if self.writer_nframe is not None:
self.writer_nframe.write(f'{key} {len(value[0])}\n')
SoundWriter (BaseWriter)
¶
SoundWriter
Examples:
>>> fs = 16000
>>> with SoundWriter('ark,scp:outdir,out.scp') as f:
...     f['key'] = fs, array
Source code in adviser/tools/espnet_minimal/utils/cli_writers.py
class SoundWriter(BaseWriter):
"""SoundWriter
Examples:
>>> fs = 16000
>>> with SoundWriter('ark,scp:outdir,out.scp') as f:
... f['key'] = fs, array
"""
def __init__(self, wspecifier, write_num_frames=None, pcm_format='wav'):
self.pcm_format = pcm_format
spec_dict = parse_wspecifier(wspecifier)
# e.g. ark,scp:dirname,wav.scp
# -> The wave files are found in dirname/*.wav
self.dirname = spec_dict['ark']
Path(self.dirname).mkdir(parents=True, exist_ok=True)
self.writer = None
if 'scp' in spec_dict:
self.writer_scp = open(spec_dict['scp'], 'w', encoding='utf-8')
else:
self.writer_scp = None
if write_num_frames is not None:
self.writer_nframe = get_num_frames_writer(write_num_frames)
else:
self.writer_nframe = None
def __setitem__(self, key, value):
assert_scipy_wav_style(value)
rate, signal = value
wavfile = Path(self.dirname) / (key + '.' + self.pcm_format)
soundfile.write(wavfile, signal.astype(numpy.int16), rate)
if self.writer_scp is not None:
self.writer_scp.write(f'{key} {wavfile}\n')
if self.writer_nframe is not None:
self.writer_nframe.write(f'{key} {len(signal)}\n')
__init__(self, wspecifier, write_num_frames=None, pcm_format='wav')
special
¶
Source code in adviser/tools/espnet_minimal/utils/cli_writers.py
def __init__(self, wspecifier, write_num_frames=None, pcm_format='wav'):
self.pcm_format = pcm_format
spec_dict = parse_wspecifier(wspecifier)
# e.g. ark,scp:dirname,wav.scp
# -> The wave files are found in dirname/*.wav
self.dirname = spec_dict['ark']
Path(self.dirname).mkdir(parents=True, exist_ok=True)
self.writer = None
if 'scp' in spec_dict:
self.writer_scp = open(spec_dict['scp'], 'w', encoding='utf-8')
else:
self.writer_scp = None
if write_num_frames is not None:
self.writer_nframe = get_num_frames_writer(write_num_frames)
else:
self.writer_nframe = None
__setitem__(self, key, value)
special
¶
Source code in adviser/tools/espnet_minimal/utils/cli_writers.py
def __setitem__(self, key, value):
assert_scipy_wav_style(value)
rate, signal = value
wavfile = Path(self.dirname) / (key + '.' + self.pcm_format)
soundfile.write(wavfile, signal.astype(numpy.int16), rate)
if self.writer_scp is not None:
self.writer_scp.write(f'{key} {wavfile}\n')
if self.writer_nframe is not None:
self.writer_nframe.write(f'{key} {len(signal)}\n')
file_writer_helper(wspecifier, filetype='mat', write_num_frames=None, compress=False, compression_method=2, pcm_format='wav')
¶
Write matrices in kaldi style
Parameters:
Name | Type | Description | Default
---|---|---|---
wspecifier | str | e.g. ark,scp:out.ark,out.scp | required
filetype | str | "mat" is kaldi-matrix, "hdf5": HDF5 | 'mat'
write_num_frames | str | e.g. 'ark,t:num_frames.txt' | None
compress | bool | Compress or not | False
compression_method | int | Specify compression level | 2
Write in kaldi-matrix ark with a "kaldi-scp" file:
>>> with file_writer_helper('ark,scp:out.ark,out.scp') as f:
...     f['uttid'] = array
The resulting "scp" file has the following format:
uttidA out.ark:1234
uttidB out.ark:2222
where 1234 and 2222 point to the starting byte address of the matrix (for details, see the official Kaldi documentation).
Write in HDF5 with an "scp" file:
>>> with file_writer_helper('ark,scp:out.h5,out.scp', 'hdf5') as f:
...     f['uttid'] = array
The "scp" file is created as:
uttidA out.h5:uttidA
uttidB out.h5:uttidB
Unlike "kaldi-ark", HDF5 can be accessed by any key, so an "scp" file is not strictly required for random reading. Nevertheless an "scp" file is created for HDF5 because it is useful for some use cases, e.g. concatenation and splitting.
Source code in adviser/tools/espnet_minimal/utils/cli_writers.py
def file_writer_helper(wspecifier: str, filetype: str = 'mat',
write_num_frames: str = None,
compress: bool = False,
compression_method: int = 2,
pcm_format: str = 'wav'):
"""Write matrices in kaldi style
Args:
wspecifier: e.g. ark,scp:out.ark,out.scp
filetype: "mat" is kaldi-martix, "hdf5": HDF5
write_num_frames: e.g. 'ark,t:num_frames.txt'
compress: Compress or not
compression_method: Specify compression level
Write in kaldi-matrix-ark with "kaldi-scp" file:
>>> with file_writer_helper('ark,scp:out.ark,out.scp') as f:
>>> f['uttid'] = array
This "scp" has the following format:
uttidA out.ark:1234
uttidB out.ark:2222
where, 1234 and 2222 points the strating byte address of the matrix.
(For detail, see official documentation of Kaldi)
Write in HDF5 with "scp" file:
>>> with file_writer_helper('ark,scp:out.h5,out.scp', 'hdf5') as f:
>>> f['uttid'] = array
This "scp" file is created as:
uttidA out.h5:uttidA
uttidB out.h5:uttidB
HDF5 can be, unlike "kaldi-ark", accessed to any keys,
so originally "scp" is not required for random-reading.
Nevertheless we create "scp" for HDF5 because it is useful
for some use-case. e.g. Concatenation, Splitting.
"""
if filetype == 'mat':
return KaldiWriter(wspecifier, write_num_frames=write_num_frames,
compress=compress,
compression_method=compression_method)
elif filetype == 'hdf5':
return HDF5Writer(wspecifier, write_num_frames=write_num_frames,
compress=compress)
elif filetype == 'sound.hdf5':
return SoundHDF5Writer(wspecifier, write_num_frames=write_num_frames,
pcm_format=pcm_format)
elif filetype == 'sound':
return SoundWriter(wspecifier, write_num_frames=write_num_frames,
pcm_format=pcm_format)
else:
raise NotImplementedError(f'filetype={filetype}')
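A usage sketch (file names are placeholders), writing one feature matrix to HDF5 together with an scp index and a num-frames file:
import numpy as np

array = np.random.randn(100, 80).astype(np.float32)
with file_writer_helper('ark,scp:out.h5,out.scp', filetype='hdf5',
                        write_num_frames='ark,t:num_frames.txt') as f:
    f['uttid'] = array  # out.h5 gets the data, out.scp and num_frames.txt get one line each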
get_num_frames_writer(write_num_frames)
¶
get_num_frames_writer
Examples:
>>> get_num_frames_writer('ark,t:num_frames.txt')
Source code in adviser/tools/espnet_minimal/utils/cli_writers.py
def get_num_frames_writer(write_num_frames: str):
"""get_num_frames_writer
Examples:
>>> get_num_frames_writer('ark,t:num_frames.txt')
"""
if write_num_frames is not None:
if ':' not in write_num_frames:
raise ValueError('Must include ":", write_num_frames={}'
.format(write_num_frames))
nframes_type, nframes_file = write_num_frames.split(':', 1)
if nframes_type != 'ark,t':
raise ValueError(
'Only supporting text mode. '
'e.g. --write-num-frames=ark,t:foo.txt :'
'{}'.format(nframes_type))
return open(nframes_file, 'w', encoding='utf-8')
parse_wspecifier(wspecifier)
¶
Parse wspecifier to dict
Examples:
>>> parse_wspecifier('ark,scp:out.ark,out.scp')
{'ark': 'out.ark', 'scp': 'out.scp'}
Source code in adviser/tools/espnet_minimal/utils/cli_writers.py
def parse_wspecifier(wspecifier: str) -> Dict[str, str]:
"""Parse wspecifier to dict
Examples:
>>> parse_wspecifier('ark,scp:out.ark,out.scp')
{'ark': 'out.ark', 'scp': 'out.scp'}
"""
ark_scp, filepath = wspecifier.split(':', 1)
if ark_scp not in ['ark', 'scp,ark', 'ark,scp']:
raise ValueError(
'{} is not allowed: {}'.format(ark_scp, wspecifier))
ark_scps = ark_scp.split(',')
filepaths = filepath.split(',')
if len(ark_scps) != len(filepaths):
raise ValueError(
'Mismatch: {} and {}'.format(ark_scp, filepath))
spec_dict = dict(zip(ark_scps, filepaths))
return spec_dict
dataset
¶
This module contains PyTorch dataset and dataloader implementations for chainer-style training.
ChainerDataLoader
¶
Pytorch dataloader in chainer style.
Source code in adviser/tools/espnet_minimal/utils/dataset.py
class ChainerDataLoader(object):
"""Pytorch dataloader in chainer style.
Args:
all args for torch.utils.data.dataloader.Dataloader
"""
def __init__(self, **kwargs):
"""Init function."""
self.loader = torch.utils.data.dataloader.DataLoader(**kwargs)
self.len = len(kwargs['dataset'])
self.current_position = 0
self.epoch = 0
self.iter = None
self.kwargs = kwargs
def next(self):
"""Implement next function."""
if self.iter is None:
self.iter = iter(self.loader)
try:
ret = next(self.iter)
except StopIteration:
self.iter = None
return self.next()
self.current_position += 1
if self.current_position == self.len:
self.epoch = self.epoch + 1
self.current_position = 0
return ret
def __iter__(self):
"""Implement iter function."""
for batch in self.loader:
yield batch
@property
def epoch_detail(self):
"""Epoch_detail required by chainer."""
return self.epoch + self.current_position / self.len
def serialize(self, serializer):
"""Serialize and deserialize function."""
epoch = serializer('epoch', self.epoch)
current_position = serializer('current_position', self.current_position)
self.epoch = epoch
self.current_position = current_position
def start_shuffle(self):
"""Shuffle function for sortagrad."""
self.kwargs['shuffle'] = True
self.loader = torch.utils.data.dataloader.DataLoader(**self.kwargs)
def finalize(self):
"""Implement finalize function."""
del self.loader
epoch_detail
property
readonly
¶
Epoch_detail required by chainer.
__init__(self, **kwargs)
special
¶
Init function.
__iter__(self)
special
¶
finalize(self)
¶
next(self)
¶
Implement next function.
Source code in adviser/tools/espnet_minimal/utils/dataset.py
def next(self):
"""Implement next function."""
if self.iter is None:
self.iter = iter(self.loader)
try:
ret = next(self.iter)
except StopIteration:
self.iter = None
return self.next()
self.current_position += 1
if self.current_position == self.len:
self.epoch = self.epoch + 1
self.current_position = 0
return ret
serialize(self, serializer)
¶
Serialize and deserialize function.
Source code in adviser/tools/espnet_minimal/utils/dataset.py
start_shuffle(self)
¶
TransformDataset (Dataset)
¶
Transform Dataset for pytorch backend.
Parameters:
Name | Type | Description | Default
---|---|---|---
data |  | list object from make_batchset | required
transform |  | transform function | required
Source code in adviser/tools/espnet_minimal/utils/dataset.py
class TransformDataset(torch.utils.data.Dataset):
"""Transform Dataset for pytorch backend.
Args:
data: list object from make_batchset
transfrom: transform function
"""
def __init__(self, data, transform):
"""Init function."""
super(TransformDataset).__init__()
self.data = data
self.transform = transform
def __len__(self):
"""Len function."""
return len(self.data)
def __getitem__(self, idx):
"""[] operator."""
return self.transform(self.data[idx])
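A usage sketch combining the two classes in this module; the toy data list and the identity transform are placeholders:
data = [('utt%d' % i, {'value': i}) for i in range(4)]
dataset = TransformDataset(data, transform=lambda item: item)
loader = ChainerDataLoader(dataset=dataset, batch_size=1, shuffle=False,
                           collate_fn=lambda batch: batch[0])
item = loader.next()        # one transformed item per call
print(loader.epoch_detail)  # fraction of the epoch consumed so far, e.g. 0.25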
deterministic_utils
¶
set_deterministic_chainer(args)
¶
Ensures chainer produces deterministic results depending on the program arguments
:param Namespace args: The program arguments
Source code in adviser/tools/espnet_minimal/utils/deterministic_utils.py
def set_deterministic_chainer(args):
"""Ensures chainer produces deterministic results depending on the program arguments
:param Namespace args: The program arguments
"""
# seed setting (chainer seed may not need it)
os.environ['CHAINER_SEED'] = str(args.seed)
logging.info('chainer seed = ' + os.environ['CHAINER_SEED'])
# debug mode setting
# 0 would be fastest, but 1 seems to be reasonable
# considering reproducibility
# remove type check
if args.debugmode < 2:
chainer.config.type_check = False
logging.info('chainer type check is disabled')
# use deterministic computation or not
if args.debugmode < 1:
chainer.config.cudnn_deterministic = False
logging.info('chainer cudnn deterministic is disabled')
else:
chainer.config.cudnn_deterministic = True
set_deterministic_pytorch(args)
¶
Ensures pytorch produces deterministic results depending on the program arguments
:param Namespace args: The program arguments
Source code in adviser/tools/espnet_minimal/utils/deterministic_utils.py
def set_deterministic_pytorch(args):
"""Ensures pytorch produces deterministic results depending on the program arguments
:param Namespace args: The program arguments
"""
# seed setting
torch.manual_seed(args.seed)
# debug mode setting
# 0 would be fastest, but 1 seems to be reasonable
# considering reproducibility
# remove type check
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False # https://github.com/pytorch/pytorch/issues/6351
if args.debugmode < 2:
chainer.config.type_check = False
logging.info('torch type check is disabled')
# use deterministic computation or not
if args.debugmode < 1:
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True
logging.info('torch cudnn deterministic is disabled')
dynamic_import
¶
dynamic_import(import_path, alias={})
¶
Dynamically import a module and class.
:param str import_path: syntax 'module_name:class_name', e.g. 'services.hci.speech.espnet_minimal.transform.add_deltas:AddDeltas'
:param dict alias: shortcut for a registered class
:return: imported class
Source code in adviser/tools/espnet_minimal/utils/dynamic_import.py
def dynamic_import(import_path, alias=dict()):
"""dynamic import module and class
:param str import_path: syntax 'module_name:class_name'
e.g., 'services.hci.speech.espnet_minimal.transform.add_deltas:AddDeltas'
:param dict alias: shortcut for registered class
:return: imported class
"""
if import_path not in alias and ':' not in import_path:
raise ValueError(
'import_path should be one of {} or '
'include ":", e.g. "services.hci.speech.espnet_minimal.transform.add_deltas:AddDeltas" : '
'{}'.format(set(alias), import_path))
if ':' not in import_path:
import_path = alias[import_path]
module_name, objname = import_path.split(':')
m = importlib.import_module(module_name)
return getattr(m, objname)
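For illustration, the same 'module:classname' syntax works with any importable module, e.g. the standard library:
# Direct path.
OrderedDict = dynamic_import('collections:OrderedDict')
d = OrderedDict(a=1, b=2)

# Short names can be registered through the alias dict.
cls = dynamic_import('od', alias={'od': 'collections:OrderedDict'})
assert cls is OrderedDict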
fill_missing_args
¶
fill_missing_args(args, add_arguments)
¶
Fill missing arguments in args.
Parameters:
Name | Type | Description | Default
---|---|---|---
args | Namespace or None | Namespace containing hyperparameters. | required
add_arguments | function | Function to add arguments. | required
Returns:
Type | Description
---|---
Namespace | Arguments whose missing ones are filled with default values.
Examples:
>>> from argparse import Namespace
>>> from services.hci.speech.espnet_minimal.nets.pytorch_backend.e2e_tts_tacotron2 import Tacotron2
>>> args = Namespace()
>>> fill_missing_args(args, Tacotron2.add_arguments_fn)
Namespace(aconv_chans=32, aconv_filts=15, adim=512, atype='location', ...)
Source code in adviser/tools/espnet_minimal/utils/fill_missing_args.py
def fill_missing_args(args, add_arguments):
"""Fill missing arguments in args.
Args:
args (Namespace or None): Namesapce containing hyperparameters.
add_arguments (function): Function to add arguments.
Returns:
Namespace: Arguments whose missing ones are filled with default value.
Examples:
>>> from argparse import Namespace
>>> from services.hci.speech.espnet_minimal.nets.pytorch_backend.e2e_tts_tacotron2 import Tacotron2
>>> args = Namespace()
>>> fill_missing_args(args, Tacotron2.add_arguments_fn)
Namespace(aconv_chans=32, aconv_filts=15, adim=512, atype='location', ...)
"""
# check argument type
assert isinstance(args, argparse.Namespace) or args is None
assert callable(add_arguments)
# get default arguments
default_args, _ = add_arguments(argparse.ArgumentParser()).parse_known_args()
# convert to dict
args = {} if args is None else vars(args)
default_args = vars(default_args)
for key, value in default_args.items():
if key not in args:
logging.info("attribute \"%s\" does not exist. use default %s." % (key, str(value)))
args[key] = value
# Note from Florian:
# I believe this is where the wrong
# arguments are introduced... no idea
# however where the arguments we actually
# load go missing.
return argparse.Namespace(**args)
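A self-contained sketch with a made-up add_arguments function (the Tacotron2 example in the docstring needs the full model class):
import argparse


def add_arguments(parser):
    parser.add_argument('--adim', type=int, default=512)
    parser.add_argument('--dropout', type=float, default=0.5)
    return parser


args = argparse.Namespace(adim=256)            # 'dropout' is missing
args = fill_missing_args(args, add_arguments)
print(args.adim, args.dropout)                 # 256 0.5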
io_utils
¶
LoadInputsAndTargets
¶
Create a mini-batch from a list of dicts.
>>> batch = [('utt1',
...           dict(input=[dict(feat='some.ark:123',
...                            filetype='mat',
...                            name='input1',
...                            shape=[100, 80])],
...                output=[dict(tokenid='1 2 3 4',
...                             name='target1',
...                             shape=[4, 31])]))]
>>> l = LoadInputsAndTargets()
>>> feat, target = l(batch)
:param str mode: Specify the task mode, "asr" or "tts"
:param str preprocess_conf: The path of a json file for pre-processing
:param bool load_input: If False, do not load the input data
:param bool load_output: If False, do not load the output data
:param bool sort_in_input_length: Sort the mini-batch in descending order of the input length
:param bool use_speaker_embedding: Used for tts mode only
:param bool use_second_target: Used for tts mode only
:param dict preprocess_args: Set some optional arguments for preprocessing
:param Optional[dict] preprocess_args: Used for tts mode only
Source code in adviser/tools/espnet_minimal/utils/io_utils.py
class LoadInputsAndTargets(object):
"""Create a mini-batch from a list of dicts
>>> batch = [('utt1',
... dict(input=[dict(feat='some.ark:123',
... filetype='mat',
... name='input1',
... shape=[100, 80])],
... output=[dict(tokenid='1 2 3 4',
... name='target1',
... shape=[4, 31])]]))
>>> l = LoadInputsAndTargets()
>>> feat, target = l(batch)
:param: str mode: Specify the task mode, "asr" or "tts"
:param: str preprocess_conf: The path of a json file for pre-processing
:param: bool load_input: If False, not to load the input data
:param: bool load_output: If False, not to load the output data
:param: bool sort_in_input_length: Sort the mini-batch in descending order
of the input length
:param: bool use_speaker_embedding: Used for tts mode only
:param: bool use_second_target: Used for tts mode only
:param: dict preprocess_args: Set some optional arguments for preprocessing
:param: Optional[dict] preprocess_args: Used for tts mode only
"""
def __init__(self, mode='asr',
preprocess_conf=None,
load_input=True,
load_output=True,
sort_in_input_length=True,
use_speaker_embedding=False,
use_second_target=False,
preprocess_args=None,
keep_all_data_on_mem=False,
):
self._loaders = {}
if mode not in ['asr', 'tts', 'mt']:
raise ValueError(
'Only asr or tts are allowed: mode={}'.format(mode))
if preprocess_conf is not None:
self.preprocessing = Transformation(preprocess_conf)
logging.warning(
'[Experimental feature] Some preprocessing will be done '
'for the mini-batch creation using {}'
.format(self.preprocessing))
else:
# If conf doesn't exist, this function don't touch anything.
self.preprocessing = None
if use_second_target and use_speaker_embedding and mode == 'tts':
raise ValueError('Choose one of "use_second_target" and '
'"use_speaker_embedding "')
if (use_second_target or use_speaker_embedding) and mode != 'tts':
logging.warning(
'"use_second_target" and "use_speaker_embedding" is '
'used only for tts mode')
self.mode = mode
self.load_output = load_output
self.load_input = load_input
self.sort_in_input_length = sort_in_input_length
self.use_speaker_embedding = use_speaker_embedding
self.use_second_target = use_second_target
if preprocess_args is None:
self.preprocess_args = {}
else:
assert isinstance(preprocess_args, dict), type(preprocess_args)
self.preprocess_args = dict(preprocess_args)
self.keep_all_data_on_mem = keep_all_data_on_mem
def __call__(self, batch):
"""Function to load inputs and targets from list of dicts
:param List[Tuple[str, dict]] batch: list of dict which is subset of
loaded data.json
:return: list of input token id sequences [(L_1), (L_2), ..., (L_B)]
:return: list of input feature sequences
[(T_1, D), (T_2, D), ..., (T_B, D)]
:rtype: list of float ndarray
:return: list of target token id sequences [(L_1), (L_2), ..., (L_B)]
:rtype: list of int ndarray
"""
x_feats_dict = OrderedDict() # OrderedDict[str, List[np.ndarray]]
y_feats_dict = OrderedDict() # OrderedDict[str, List[np.ndarray]]
uttid_list = [] # List[str]
for uttid, info in batch:
uttid_list.append(uttid)
if self.load_input:
# Note(kamo): This for-loop is for multiple inputs
for idx, inp in enumerate(info['input']):
# {"input":
# [{"feat": "some/path.h5:F01_050C0101_PED_REAL",
# "filetype": "hdf5",
# "name": "input1", ...}], ...}
x = self._get_from_loader(
filepath=inp['feat'],
filetype=inp.get('filetype', 'mat'))
x_feats_dict.setdefault(inp['name'], []).append(x)
# FIXME(kamo): Dirty way to load only speaker_embedding without the other inputs
elif self.mode == 'tts' and self.use_speaker_embedding:
for idx, inp in enumerate(info['input']):
if idx != 1 and len(info['input']) > 1:
x = None
else:
x = self._get_from_loader(
filepath=inp['feat'],
filetype=inp.get('filetype', 'mat'))
x_feats_dict.setdefault(inp['name'], []).append(x)
if self.load_output:
if self.mode == 'mt':
x = np.fromiter(map(int, info['output'][1]['tokenid'].split()),
dtype=np.int64)
x_feats_dict.setdefault(info['output'][1]['name'], []).append(x)
for idx, inp in enumerate(info['output']):
if 'tokenid' in inp:
# ======= Legacy format for output =======
# {"output": [{"tokenid": "1 2 3 4"}])
x = np.fromiter(map(int, inp['tokenid'].split()),
dtype=np.int64)
else:
# ======= New format =======
# {"input":
# [{"feat": "some/path.h5:F01_050C0101_PED_REAL",
# "filetype": "hdf5",
# "name": "target1", ...}], ...}
x = self._get_from_loader(
filepath=inp['feat'],
filetype=inp.get('filetype', 'mat'))
y_feats_dict.setdefault(inp['name'], []).append(x)
if self.mode == 'asr':
return_batch, uttid_list = self._create_batch_asr(
x_feats_dict, y_feats_dict, uttid_list)
elif self.mode == 'tts':
_, info = batch[0]
eos = int(info['output'][0]['shape'][1]) - 1
return_batch, uttid_list = self._create_batch_tts(
x_feats_dict, y_feats_dict, uttid_list, eos)
elif self.mode == 'mt':
return_batch, uttid_list = self._create_batch_mt(
x_feats_dict, y_feats_dict, uttid_list)
else:
raise NotImplementedError
if self.preprocessing is not None:
# Apply pre-processing all input features
for x_name in return_batch.keys():
if x_name.startswith("input"):
return_batch[x_name] = self.preprocessing(
return_batch[x_name], uttid_list, **self.preprocess_args)
# Doesn't return the names now.
return tuple(return_batch.values())
def _create_batch_asr(self, x_feats_dict, y_feats_dict, uttid_list):
"""Create a OrderedDict for the mini-batch
:param OrderedDict x_feats_dict:
e.g. {"input1": [ndarray, ndarray, ...],
"input2": [ndarray, ndarray, ...]}
:param OrderedDict y_feats_dict:
e.g. {"target1": [ndarray, ndarray, ...],
"target2": [ndarray, ndarray, ...]}
:param: List[str] uttid_list:
Give uttid_list to sort in the same order as the mini-batch
:return: batch, uttid_list
:rtype: Tuple[OrderedDict, List[str]]
"""
# handle single-input and multi-input (parallel) asr mode
xs = list(x_feats_dict.values())
if self.load_output:
if len(y_feats_dict) == 1:
ys = list(y_feats_dict.values())[0]
assert len(xs[0]) == len(ys), (len(xs[0]), len(ys))
# get index of non-zero length samples
nonzero_idx = list(filter(lambda i: len(ys[i]) > 0, range(len(ys))))
elif len(y_feats_dict) > 1: # multi-speaker asr mode
ys = list(y_feats_dict.values())
assert len(xs[0]) == len(ys[0]), (len(xs[0]), len(ys[0]))
# get index of non-zero length samples
nonzero_idx = list(filter(lambda i: len(ys[0][i]) > 0, range(len(ys[0]))))
for n in range(1, len(y_feats_dict)):
nonzero_idx = filter(lambda i: len(ys[n][i]) > 0, nonzero_idx)
else:
# Note(kamo): Be careful not to turn nonzero_idx into a generator
nonzero_idx = list(range(len(xs[0])))
if self.sort_in_input_length:
# sort in input lengths based on the first input
nonzero_sorted_idx = sorted(nonzero_idx, key=lambda i: -len(xs[0][i]))
else:
nonzero_sorted_idx = nonzero_idx
if len(nonzero_sorted_idx) != len(xs[0]):
logging.warning(
'Target sequences include empty tokenid (batch {} -> {}).'
.format(len(xs[0]), len(nonzero_sorted_idx)))
# remove zero-length samples
xs = [[x[i] for i in nonzero_sorted_idx] for x in xs]
uttid_list = [uttid_list[i] for i in nonzero_sorted_idx]
x_names = list(x_feats_dict.keys())
if self.load_output:
if len(y_feats_dict) == 1:
ys = [ys[i] for i in nonzero_sorted_idx]
elif len(y_feats_dict) > 1: # multi-speaker asr mode
ys = zip(*[[y[i] for i in nonzero_sorted_idx] for y in ys])
y_name = list(y_feats_dict.keys())[0]
# Keeping x_name and y_name, e.g. input1, for future extension
return_batch = OrderedDict([*[(x_name, x) for x_name, x in zip(x_names, xs)], (y_name, ys)])
else:
return_batch = OrderedDict([(x_name, x) for x_name, x in zip(x_names, xs)])
return return_batch, uttid_list
def _create_batch_mt(self, x_feats_dict, y_feats_dict, uttid_list):
"""Create a OrderedDict for the mini-batch
:param OrderedDict x_feats_dict:
:param OrderedDict y_feats_dict:
:return: batch, uttid_list
:rtype: Tuple[OrderedDict, List[str]]
"""
# Create a list from the first item
xs = list(x_feats_dict.values())[0]
if self.load_output:
ys = list(y_feats_dict.values())[0]
assert len(xs) == len(ys), (len(xs), len(ys))
# get index of non-zero length samples
nonzero_idx = filter(lambda i: len(ys[i]) > 0, range(len(ys)))
else:
nonzero_idx = range(len(xs))
if self.sort_in_input_length:
# sort in input lengths
nonzero_sorted_idx = sorted(nonzero_idx, key=lambda i: -len(xs[i]))
else:
nonzero_sorted_idx = nonzero_idx
if len(nonzero_sorted_idx) != len(xs):
logging.warning(
'Target sequences include empty tokenid (batch {} -> {}).'
.format(len(xs), len(nonzero_sorted_idx)))
# remove zero-length samples
xs = [xs[i] for i in nonzero_sorted_idx]
uttid_list = [uttid_list[i] for i in nonzero_sorted_idx]
x_name = list(x_feats_dict.keys())[0]
if self.load_output:
ys = [ys[i] for i in nonzero_sorted_idx]
y_name = list(y_feats_dict.keys())[0]
return_batch = OrderedDict([(x_name, xs), (y_name, ys)])
else:
return_batch = OrderedDict([(x_name, xs)])
return return_batch, uttid_list
def _create_batch_tts(self, x_feats_dict, y_feats_dict, uttid_list, eos):
"""Create a OrderedDict for the mini-batch
:param OrderedDict x_feats_dict:
e.g. {"input1": [ndarray, ndarray, ...],
"input2": [ndarray, ndarray, ...]}
:param OrderedDict y_feats_dict:
e.g. {"target1": [ndarray, ndarray, ...],
"target2": [ndarray, ndarray, ...]}
:param: List[str] uttid_list:
:param int eos:
:return: batch, uttid_list
:rtype: Tuple[OrderedDict, List[str]]
"""
# Use the output values as the input feats for tts mode
xs = list(y_feats_dict.values())[0]
# get index of non-zero length samples
nonzero_idx = list(filter(lambda i: len(xs[i]) > 0, range(len(xs))))
# sort in input lengths
if self.sort_in_input_length:
# sort in input lengths
nonzero_sorted_idx = sorted(nonzero_idx, key=lambda i: -len(xs[i]))
else:
nonzero_sorted_idx = nonzero_idx
# remove zero-length samples
xs = [xs[i] for i in nonzero_sorted_idx]
uttid_list = [uttid_list[i] for i in nonzero_sorted_idx]
# Added eos into input sequence
xs = [np.append(x, eos) for x in xs]
if self.load_input:
ys = list(x_feats_dict.values())[0]
assert len(xs) == len(ys), (len(xs), len(ys))
ys = [ys[i] for i in nonzero_sorted_idx]
spembs = None
spcs = None
spembs_name = 'spembs_none'
spcs_name = 'spcs_none'
if self.use_second_target:
spcs = list(x_feats_dict.values())[1]
spcs = [spcs[i] for i in nonzero_sorted_idx]
spcs_name = list(x_feats_dict.keys())[1]
if self.use_speaker_embedding:
spembs = list(x_feats_dict.values())[1]
spembs = [spembs[i] for i in nonzero_sorted_idx]
spembs_name = list(x_feats_dict.keys())[1]
x_name = list(y_feats_dict.keys())[0]
y_name = list(x_feats_dict.keys())[0]
return_batch = OrderedDict([(x_name, xs),
(y_name, ys),
(spembs_name, spembs),
(spcs_name, spcs)])
elif self.use_speaker_embedding:
if len(x_feats_dict) == 0:
raise IndexError('No speaker embedding is provided')
elif len(x_feats_dict) == 1:
spembs_idx = 0
else:
spembs_idx = 1
spembs = list(x_feats_dict.values())[spembs_idx]
spembs = [spembs[i] for i in nonzero_sorted_idx]
x_name = list(y_feats_dict.keys())[0]
spembs_name = list(x_feats_dict.keys())[spembs_idx]
return_batch = OrderedDict([(x_name, xs),
(spembs_name, spembs)])
else:
x_name = list(y_feats_dict.keys())[0]
return_batch = OrderedDict([(x_name, xs)])
return return_batch, uttid_list
def _get_from_loader(self, filepath, filetype):
"""Return ndarray
In order to open each file descriptor only on first access,
the loaders are cached in self._loaders
>>> ndarray = loader.get_from_loader(
... 'some/path.h5:F01_050C0101_PED_REAL', filetype='hdf5')
:param: str filepath:
:param: str filetype:
:return:
:rtype: np.ndarray
"""
if filetype == 'hdf5':
# e.g.
# {"input": [{"feat": "some/path.h5:F01_050C0101_PED_REAL",
# "filetype": "hdf5",
# -> filepath = "some/path.h5", key = "F01_050C0101_PED_REAL"
filepath, key = filepath.split(':', 1)
loader = self._loaders.get(filepath)
if loader is None:
# To avoid disk access, create loader only for the first time
loader = h5py.File(filepath, 'r')
self._loaders[filepath] = loader
return loader[key][()]
elif filetype == 'sound.hdf5':
# e.g.
# {"input": [{"feat": "some/path.h5:F01_050C0101_PED_REAL",
# "filetype": "sound.hdf5",
# -> filepath = "some/path.h5", key = "F01_050C0101_PED_REAL"
filepath, key = filepath.split(':', 1)
loader = self._loaders.get(filepath)
if loader is None:
# To avoid disk access, create loader only for the first time
loader = SoundHDF5File(filepath, 'r', dtype='int16')
self._loaders[filepath] = loader
array, rate = loader[key]
return array
elif filetype == 'sound':
# e.g.
# {"input": [{"feat": "some/path.wav",
# "filetype": "sound"},
# Assume PCM16
if not self.keep_all_data_on_mem:
array, _ = soundfile.read(filepath, dtype='int16')
return array
if filepath not in self._loaders:
array, _ = soundfile.read(filepath, dtype='int16')
self._loaders[filepath] = array
return self._loaders[filepath]
elif filetype == 'npz':
# e.g.
# {"input": [{"feat": "some/path.npz:F01_050C0101_PED_REAL",
# "filetype": "npz",
filepath, key = filepath.split(':', 1)
loader = self._loaders.get(filepath)
if loader is None:
# To avoid disk access, create loader only for the first time
loader = np.load(filepath)
self._loaders[filepath] = loader
return loader[key]
elif filetype == 'npy':
# e.g.
# {"input": [{"feat": "some/path.npy",
# "filetype": "npy"},
if not self.keep_all_data_on_mem:
return np.load(filepath)
if filepath not in self._loaders:
self._loaders[filepath] = np.load(filepath)
return self._loaders[filepath]
elif filetype in ['mat', 'vec']:
# e.g.
# {"input": [{"feat": "some/path.ark:123",
# "filetype": "mat"}]},
# In this case, "123" indicates the starting points of the matrix
# load_mat can load both matrix and vector
if not self.keep_all_data_on_mem:
return kaldiio.load_mat(filepath)
if filepath not in self._loaders:
self._loaders[filepath] = kaldiio.load_mat(filepath)
return self._loaders[filepath]
elif filetype == 'scp':
# e.g.
# {"input": [{"feat": "some/path.scp:F01_050C0101_PED_REAL",
# "filetype": "scp",
filepath, key = filepath.split(':', 1)
loader = self._loaders.get(filepath)
if loader is None:
# To avoid disk access, create loader only for the first time
loader = kaldiio.load_scp(filepath)
self._loaders[filepath] = loader
return loader[key]
else:
raise NotImplementedError(
'Not supported: loader_type={}'.format(filetype))
__call__(self, batch)
special
¶
Function to load inputs and targets from list of dicts
:param List[Tuple[str, dict]] batch: list of dict which is subset of loaded data.json
:return: list of input token id sequences [(L_1), (L_2), ..., (L_B)]
:return: list of input feature sequences [(T_1, D), (T_2, D), ..., (T_B, D)]
:rtype: list of float ndarray
:return: list of target token id sequences [(L_1), (L_2), ..., (L_B)]
:rtype: list of int ndarray
Source code in adviser/tools/espnet_minimal/utils/io_utils.py
def __call__(self, batch):
"""Function to load inputs and targets from list of dicts
:param List[Tuple[str, dict]] batch: list of dict which is subset of
loaded data.json
:return: list of input token id sequences [(L_1), (L_2), ..., (L_B)]
:return: list of input feature sequences
[(T_1, D), (T_2, D), ..., (T_B, D)]
:rtype: list of float ndarray
:return: list of target token id sequences [(L_1), (L_2), ..., (L_B)]
:rtype: list of int ndarray
"""
x_feats_dict = OrderedDict() # OrderedDict[str, List[np.ndarray]]
y_feats_dict = OrderedDict() # OrderedDict[str, List[np.ndarray]]
uttid_list = [] # List[str]
for uttid, info in batch:
uttid_list.append(uttid)
if self.load_input:
# Note(kamo): This for-loop is for multiple inputs
for idx, inp in enumerate(info['input']):
# {"input":
# [{"feat": "some/path.h5:F01_050C0101_PED_REAL",
# "filetype": "hdf5",
# "name": "input1", ...}], ...}
x = self._get_from_loader(
filepath=inp['feat'],
filetype=inp.get('filetype', 'mat'))
x_feats_dict.setdefault(inp['name'], []).append(x)
# FIXME(kamo): Dirty way to load only speaker_embedding without the other inputs
elif self.mode == 'tts' and self.use_speaker_embedding:
for idx, inp in enumerate(info['input']):
if idx != 1 and len(info['input']) > 1:
x = None
else:
x = self._get_from_loader(
filepath=inp['feat'],
filetype=inp.get('filetype', 'mat'))
x_feats_dict.setdefault(inp['name'], []).append(x)
if self.load_output:
if self.mode == 'mt':
x = np.fromiter(map(int, info['output'][1]['tokenid'].split()),
dtype=np.int64)
x_feats_dict.setdefault(info['output'][1]['name'], []).append(x)
for idx, inp in enumerate(info['output']):
if 'tokenid' in inp:
# ======= Legacy format for output =======
# {"output": [{"tokenid": "1 2 3 4"}])
x = np.fromiter(map(int, inp['tokenid'].split()),
dtype=np.int64)
else:
# ======= New format =======
# {"input":
# [{"feat": "some/path.h5:F01_050C0101_PED_REAL",
# "filetype": "hdf5",
# "name": "target1", ...}], ...}
x = self._get_from_loader(
filepath=inp['feat'],
filetype=inp.get('filetype', 'mat'))
y_feats_dict.setdefault(inp['name'], []).append(x)
if self.mode == 'asr':
return_batch, uttid_list = self._create_batch_asr(
x_feats_dict, y_feats_dict, uttid_list)
elif self.mode == 'tts':
_, info = batch[0]
eos = int(info['output'][0]['shape'][1]) - 1
return_batch, uttid_list = self._create_batch_tts(
x_feats_dict, y_feats_dict, uttid_list, eos)
elif self.mode == 'mt':
return_batch, uttid_list = self._create_batch_mt(
x_feats_dict, y_feats_dict, uttid_list)
else:
raise NotImplementedError
if self.preprocessing is not None:
# Apply pre-processing all input features
for x_name in return_batch.keys():
if x_name.startswith("input"):
return_batch[x_name] = self.preprocessing(
return_batch[x_name], uttid_list, **self.preprocess_args)
# Doesn't return the names now.
return tuple(return_batch.values())
__init__(self, mode='asr', preprocess_conf=None, load_input=True, load_output=True, sort_in_input_length=True, use_speaker_embedding=False, use_second_target=False, preprocess_args=None, keep_all_data_on_mem=False)
special
¶
Source code in adviser/tools/espnet_minimal/utils/io_utils.py
def __init__(self, mode='asr',
preprocess_conf=None,
load_input=True,
load_output=True,
sort_in_input_length=True,
use_speaker_embedding=False,
use_second_target=False,
preprocess_args=None,
keep_all_data_on_mem=False,
):
self._loaders = {}
if mode not in ['asr', 'tts', 'mt']:
raise ValueError(
'Only asr or tts are allowed: mode={}'.format(mode))
if preprocess_conf is not None:
self.preprocessing = Transformation(preprocess_conf)
logging.warning(
'[Experimental feature] Some preprocessing will be done '
'for the mini-batch creation using {}'
.format(self.preprocessing))
else:
# If conf doesn't exist, this function doesn't touch anything.
self.preprocessing = None
if use_second_target and use_speaker_embedding and mode == 'tts':
raise ValueError('Choose one of "use_second_target" and '
'"use_speaker_embedding "')
if (use_second_target or use_speaker_embedding) and mode != 'tts':
logging.warning(
'"use_second_target" and "use_speaker_embedding" is '
'used only for tts mode')
self.mode = mode
self.load_output = load_output
self.load_input = load_input
self.sort_in_input_length = sort_in_input_length
self.use_speaker_embedding = use_speaker_embedding
self.use_second_target = use_second_target
if preprocess_args is None:
self.preprocess_args = {}
else:
assert isinstance(preprocess_args, dict), type(preprocess_args)
self.preprocess_args = dict(preprocess_args)
self.keep_all_data_on_mem = keep_all_data_on_mem
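For orientation, a minimal end-to-end sketch of how this converter is typically driven. The class name LoadInputsAndTargets and the import path are assumptions (taken from the upstream ESPnet module this file mirrors); only the data.json entry layout used below comes from the comments above.
import os
import tempfile
import numpy as np
# Assumed import; upstream ESPnet defines this converter as LoadInputsAndTargets.
from tools.espnet_minimal.utils.io_utils import LoadInputsAndTargets

# One utterance: a real .npy feature matrix plus a legacy token-id target.
tmpdir = tempfile.mkdtemp()
feat_path = os.path.join(tmpdir, "utt1.npy")
np.save(feat_path, np.random.randn(120, 83).astype(np.float32))  # (T, D) features
batch = [
    ("utt1", {
        "input": [{"feat": feat_path, "filetype": "npy", "name": "input1"}],
        "output": [{"tokenid": "7 12 4 9", "name": "target1"}],
    }),
]

loader = LoadInputsAndTargets(mode="asr")  # no preprocess_conf, so no transformation is applied
xs, ys = loader(batch)
print(xs[0].shape, ys[0])  # (120, 83) [ 7 12  4  9]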
SoundHDF5File
¶
Collecting sound files to a HDF5 file
>>> f = SoundHDF5File('a.flac.h5', mode='a')
>>> array = np.random.randint(0, 100, 100, dtype=np.int16)
>>> f['id'] = (array, 16000)
>>> array, rate = f['id']
:param str filepath:
:param str mode:
:param str format: The type used when saving wav. flac, nist, htk, etc.
:param str dtype:
Source code in adviser/tools/espnet_minimal/utils/io_utils.py
class SoundHDF5File(object):
"""Collecting sound files to a HDF5 file
>>> f = SoundHDF5File('a.flac.h5', mode='a')
>>> array = np.random.randint(0, 100, 100, dtype=np.int16)
>>> f['id'] = (array, 16000)
>>> array, rate = f['id']
:param: str filepath:
:param: str mode:
:param: str format: The type used when saving wav. flac, nist, htk, etc.
:param: str dtype:
"""
def __init__(self, filepath, mode='r+', format=None, dtype='int16',
**kwargs):
self.filepath = filepath
self.mode = mode
self.dtype = dtype
self.file = h5py.File(filepath, mode, **kwargs)
if format is None:
# filepath = a.flac.h5 -> format = flac
second_ext = os.path.splitext(os.path.splitext(filepath)[0])[1]
format = second_ext[1:]
if format.upper() not in soundfile.available_formats():
# If not found, flac is selected
format = 'flac'
# This format affects only saving
self.format = format
def __repr__(self):
return '<SoundHDF5 file "{}" (mode {}, format {}, type {})>' \
.format(self.filepath, self.mode, self.format, self.dtype)
def create_dataset(self, name, shape=None, data=None, **kwds):
f = io.BytesIO()
array, rate = data
soundfile.write(f, array, rate, format=self.format)
self.file.create_dataset(name, shape=shape,
data=np.void(f.getvalue()), **kwds)
def __setitem__(self, name, data):
self.create_dataset(name, data=data)
def __getitem__(self, key):
data = self.file[key][()]
f = io.BytesIO(data.tobytes())
array, rate = soundfile.read(f, dtype=self.dtype)
return array, rate
def keys(self):
return self.file.keys()
def values(self):
for k in self.file:
yield self[k]
def items(self):
for k in self.file:
yield k, self[k]
def __iter__(self):
return iter(self.file)
def __contains__(self, item):
return item in self.file
def __len__(self, item):
return len(self.file)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.file.close()
def close(self):
self.file.close()
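As a complement to the doctest above, a small round-trip sketch (the path is illustrative; it assumes the class is importable from adviser/tools/espnet_minimal/utils/io_utils.py and that h5py and soundfile are installed, as the class itself requires).
import os
import tempfile
import numpy as np

path = os.path.join(tempfile.mkdtemp(), "demo.flac.h5")  # the second extension selects the codec
with SoundHDF5File(path, mode="a") as f:                 # mode "a" creates the file if missing
    wav = np.random.randint(-1000, 1000, 16000, dtype=np.int16)  # 1 s of noise at 16 kHz
    f["utt1"] = (wav, 16000)   # encoded as FLAC bytes inside an HDF5 dataset
    array, rate = f["utt1"]    # decoded back to an int16 array
    assert rate == 16000 and array.shape == wav.shape
    print(list(f.keys()))      # ['utt1']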
__contains__(self, item)
special
¶
__enter__(self)
special
¶
__exit__(self, exc_type, exc_val, exc_tb)
special
¶
__getitem__(self, key)
special
¶
__init__(self, filepath, mode='r+', format=None, dtype='int16', **kwargs)
special
¶
Source code in adviser/tools/espnet_minimal/utils/io_utils.py
def __init__(self, filepath, mode='r+', format=None, dtype='int16',
**kwargs):
self.filepath = filepath
self.mode = mode
self.dtype = dtype
self.file = h5py.File(filepath, mode, **kwargs)
if format is None:
# filepath = a.flac.h5 -> format = flac
second_ext = os.path.splitext(os.path.splitext(filepath)[0])[1]
format = second_ext[1:]
if format.upper() not in soundfile.available_formats():
# If not found, flac is selected
format = 'flac'
# This format affects only saving
self.format = format
__iter__(self)
special
¶
__len__(self, item)
special
¶
__repr__(self)
special
¶
__setitem__(self, name, data)
special
¶
close(self)
¶
create_dataset(self, name, shape=None, data=None, **kwds)
¶
items(self)
¶
keys(self)
¶
values(self)
¶
spec_augment
¶
This implementation is modified from https://github.com/zcaceres/spec_augment
MIT License
Copyright (c) 2019 Zach Caceres
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
apply_interpolation(query_points, train_points, w, v, order)
¶
Apply polyharmonic interpolation model to data.
Notes
Given coefficients w and v for the interpolation model, we evaluate interpolated function values at query_points.
Parameters:

Name | Type | Description | Default
---|---|---|---
query_points | | `[b, m, d]` x values to evaluate the interpolation at | required
train_points | | `[b, n, d]` x values that act as the interpolation centers (the c variables in the wikipedia article) | required
w | | `[b, n, k]` weights on each interpolation center | required
v | | `[b, d, k]` weights on each input dimension | required
order | | order of the interpolation | required

Returns:

Type | Description
---|---
 | Polyharmonic interpolation evaluated at points defined in query_points.
Source code in adviser/tools/espnet_minimal/utils/spec_augment.py
def apply_interpolation(query_points, train_points, w, v, order):
"""Apply polyharmonic interpolation model to data.
Notes:
Given coefficients w and v for the interpolation model, we evaluate
interpolated function values at query_points.
Args:
query_points: `[b, m, d]` x values to evaluate the interpolation at
train_points: `[b, n, d]` x values that act as the interpolation centers
( the c variables in the wikipedia article)
w: `[b, n, k]` weights on each interpolation center
v: `[b, d, k]` weights on each input dimension
order: order of the interpolation
Returns:
Polyharmonic interpolation evaluated at points defined in query_points.
"""
query_points = query_points.unsqueeze(0)
# First, compute the contribution from the rbf term.
pairwise_dists = cross_squared_distance_matrix(query_points.float(), train_points.float())
phi_pairwise_dists = phi(pairwise_dists, order)
rbf_term = torch.matmul(phi_pairwise_dists, w)
# Then, compute the contribution from the linear term.
# Pad query_points with ones, for the bias term in the linear model.
ones = torch.ones_like(query_points[..., :1])
query_points_pad = torch.cat((
query_points,
ones
), 2).float()
linear_term = torch.matmul(query_points_pad, v)
return rbf_term + linear_term
create_dense_flows(flattened_flows, batch_size, image_height, image_width)
¶
cross_squared_distance_matrix(x, y)
¶
Pairwise squared distance between two (batch) matrices' rows (2nd dim).
Computes the pairwise distances between rows of x and rows of y
Args:
x: [batch_size, n, d] float `Tensor`
y: [batch_size, m, d] float `Tensor`
Returns:
squared_dists: [batch_size, n, m] float `Tensor`, where squared_dists[b,i,j] = ||x[b,i,:] - y[b,j,:]||^2
Source code in adviser/tools/espnet_minimal/utils/spec_augment.py
def cross_squared_distance_matrix(x, y):
"""Pairwise squared distance between two (batch) matrices' rows (2nd dim).
Computes the pairwise distances between rows of x and rows of y
Args:
x: [batch_size, n, d] float `Tensor`
y: [batch_size, m, d] float `Tensor`
Returns:
squared_dists: [batch_size, n, m] float `Tensor`, where
squared_dists[b,i,j] = ||x[b,i,:] - y[b,j,:]||^2
"""
x_norm_squared = torch.sum(torch.mul(x, x))
y_norm_squared = torch.sum(torch.mul(y, y))
x_y_transpose = torch.matmul(x.squeeze(0), y.squeeze(0).transpose(0, 1))
# squared_dists[b,i,j] = ||x_bi - y_bj||^2 = x_bi'x_bi- 2x_bi'x_bj + x_bj'x_bj
squared_dists = x_norm_squared - 2 * x_y_transpose + y_norm_squared
return squared_dists.float()
dense_image_warp(image, flow)
¶
Image warping using per-pixel flow vectors.
Apply a non-linear warp to the image, where the warp is specified by a dense flow field of offset vectors that define the correspondences of pixel values in the output image back to locations in the source image. Specifically, the pixel value at output[b, j, i, c] is images[b, j - flow[b, j, i, 0], i - flow[b, j, i, 1], c]. The locations specified by this formula do not necessarily map to an int index. Therefore, the pixel value is obtained by bilinear interpolation of the 4 nearest pixels around (b, j - flow[b, j, i, 0], i - flow[b, j, i, 1]). For locations outside of the image, we use the nearest pixel values at the image boundary.
Args:
image: 4-D float `Tensor` with shape `[batch, height, width, channels]`.
flow: A 4-D float `Tensor` with shape `[batch, height, width, 2]`.
name: A name for the operation (optional).
Note that image and flow can be of type tf.half, tf.float32, or tf.float64, and do not necessarily have to be the same type.
Returns:
A 4-D float `Tensor` with shape `[batch, height, width, channels]` and same type as input image.
Raises:
ValueError: if height < 2 or width < 2 or the inputs have the wrong number of dimensions.
Source code in adviser/tools/espnet_minimal/utils/spec_augment.py
def dense_image_warp(image, flow):
"""Image warping using per-pixel flow vectors.
Apply a non-linear warp to the image, where the warp is specified by a dense
flow field of offset vectors that define the correspondences of pixel values
in the output image back to locations in the source image. Specifically, the
pixel value at output[b, j, i, c] is
images[b, j - flow[b, j, i, 0], i - flow[b, j, i, 1], c].
The locations specified by this formula do not necessarily map to an int
index. Therefore, the pixel value is obtained by bilinear
interpolation of the 4 nearest pixels around
(b, j - flow[b, j, i, 0], i - flow[b, j, i, 1]). For locations outside
of the image, we use the nearest pixel values at the image boundary.
Args:
image: 4-D float `Tensor` with shape `[batch, height, width, channels]`.
flow: A 4-D float `Tensor` with shape `[batch, height, width, 2]`.
name: A name for the operation (optional).
Note that image and flow can be of type tf.half, tf.float32, or tf.float64,
and do not necessarily have to be the same type.
Returns:
A 4-D float `Tensor` with shape`[batch, height, width, channels]`
and same type as input image.
Raises:
ValueError: if height < 2 or width < 2 or the inputs have the wrong number
of dimensions.
"""
image = image.unsqueeze(3) # add a single channel dimension to image tensor
batch_size, height, width, channels = image.shape
device = image.device
# The flow is defined on the image grid. Turn the flow into a list of query
# points in the grid space.
grid_x, grid_y = torch.meshgrid(
torch.arange(width, device=device), torch.arange(height, device=device))
stacked_grid = torch.stack((grid_y, grid_x), dim=2).float()
batched_grid = stacked_grid.unsqueeze(-1).permute(3, 1, 0, 2)
query_points_on_grid = batched_grid - flow
query_points_flattened = torch.reshape(query_points_on_grid, [batch_size, height * width, 2])
# Compute values at the query points, then reshape the result back to the
# image grid.
interpolated = interpolate_bilinear(image, query_points_flattened)
interpolated = torch.reshape(interpolated, [batch_size, height, width, channels])
return interpolated
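A quick sanity-check sketch: with an all-zero flow the warp is the identity. Note that, as in time_warp below, image is passed as a 3-D [batch, height, width] tensor and the function adds the trailing channel axis itself.
import torch

image = torch.arange(12.0).reshape(1, 3, 4)  # one single-channel 3x4 "spectrogram"
flow = torch.zeros(1, 3, 4, 2)               # zero offsets -> identity warp
warped = dense_image_warp(image, flow)       # shape (1, 3, 4, 1)
assert torch.allclose(warped.squeeze(-1), image)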
flatten_grid_locations(grid_locations, image_height, image_width)
¶
freq_mask(spec, F=30, num_masks=1, replace_with_zero=False)
¶
Frequency masking
:param torch.Tensor spec: input tensor with shape (T, dim)
:param int F: maximum width of each mask
:param int num_masks: number of masks
:param bool replace_with_zero: if True, masked parts will be filled with 0, if False, filled with mean
Source code in adviser/tools/espnet_minimal/utils/spec_augment.py
def freq_mask(spec, F=30, num_masks=1, replace_with_zero=False):
"""Frequency masking
:param torch.Tensor spec: input tensor with shape (T, dim)
:param int F: maximum width of each mask
:param int num_masks: number of masks
:param bool replace_with_zero: if True, masked parts will be filled with 0, if False, filled with mean
"""
cloned = spec.unsqueeze(0).clone()
num_mel_channels = cloned.shape[2]
for i in range(0, num_masks):
f = random.randrange(0, F)
f_zero = random.randrange(0, num_mel_channels - f)
# avoids randrange error if values are equal and range is empty
if (f_zero == f_zero + f):
return cloned.squeeze(0)
mask_end = random.randrange(f_zero, f_zero + f)
if (replace_with_zero):
cloned[0][:, f_zero:mask_end] = 0
else:
cloned[0][:, f_zero:mask_end] = cloned.mean()
return cloned.squeeze(0)
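For instance, masking a random (T, dim) spectrogram (a minimal sketch; the data is synthetic):
import random
import torch

random.seed(0)               # mask widths and positions are drawn with random.randrange
spec = torch.randn(200, 80)  # (T, dim) log-mel-like input
masked = freq_mask(spec, F=27, num_masks=2, replace_with_zero=True)
print(masked.shape)          # torch.Size([200, 80]), with up to two zeroed channel bands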
get_flat_grid_locations(image_height, image_width, device)
¶
Source code in adviser/tools/espnet_minimal/utils/spec_augment.py
def get_flat_grid_locations(image_height, image_width, device):
y_range = torch.linspace(0, image_height - 1, image_height, device=device)
x_range = torch.linspace(0, image_width - 1, image_width, device=device)
y_grid, x_grid = torch.meshgrid(y_range, x_range)
return torch.stack((y_grid, x_grid), -1).reshape([image_height * image_width, 2])
get_grid_locations(image_height, image_width, device)
¶
Source code in adviser/tools/espnet_minimal/utils/spec_augment.py
interpolate_bilinear(grid, query_points, name='interpolate_bilinear', indexing='ij')
¶
Similar to Matlab's interp2 function.
Notes
Finds values for query points on a grid using bilinear interpolation.
Parameters:

Name | Type | Description | Default
---|---|---|---
grid | | a 4-D float `Tensor` of shape `[batch, height, width, channels]` | required
query_points | | a 3-D float `Tensor` of N points with shape `[batch, N, 2]` | required
name | | a name for the operation (optional) | 'interpolate_bilinear'
indexing | | whether the query points are specified as row and column (ij), or Cartesian coordinates (xy) | 'ij'

Returns:

Type | Description
---|---
values | a 3-D `Tensor` with shape `[batch, N, channels]`

Exceptions:

Type | Description
---|---
ValueError | if the indexing mode is invalid, or if the shape of the inputs is invalid
Source code in adviser/tools/espnet_minimal/utils/spec_augment.py
def interpolate_bilinear(grid,
query_points,
name='interpolate_bilinear',
indexing='ij'):
"""Similar to Matlab's interp2 function.
Notes:
Finds values for query points on a grid using bilinear interpolation.
Args:
grid: a 4-D float `Tensor` of shape `[batch, height, width, channels]`.
query_points: a 3-D float `Tensor` of N points with shape `[batch, N, 2]`.
name: a name for the operation (optional).
indexing: whether the query points are specified as row and column (ij),
or Cartesian coordinates (xy).
Returns:
values: a 3-D `Tensor` with shape `[batch, N, channels]`
Raises:
ValueError: if the indexing mode is invalid, or if the shape of the inputs
invalid.
"""
if indexing != 'ij' and indexing != 'xy':
raise ValueError('Indexing mode must be \'ij\' or \'xy\'')
shape = grid.shape
if len(shape) != 4:
msg = 'Grid must be 4 dimensional. Received size: '
raise ValueError(msg + str(grid.shape))
batch_size, height, width, channels = grid.shape
shape = [batch_size, height, width, channels]
query_type = query_points.dtype
grid_type = grid.dtype
grid_device = grid.device
num_queries = query_points.shape[1]
alphas = []
floors = []
ceils = []
index_order = [0, 1] if indexing == 'ij' else [1, 0]
unstacked_query_points = query_points.unbind(2)
for dim in index_order:
queries = unstacked_query_points[dim]
size_in_indexing_dimension = shape[dim + 1]
# max_floor is size_in_indexing_dimension - 2 so that max_floor + 1
# is still a valid index into the grid.
max_floor = torch.tensor(size_in_indexing_dimension - 2, dtype=query_type, device=grid_device)
min_floor = torch.tensor(0.0, dtype=query_type, device=grid_device)
maxx = torch.max(min_floor, torch.floor(queries))
floor = torch.min(maxx, max_floor)
int_floor = floor.long()
floors.append(int_floor)
ceil = int_floor + 1
ceils.append(ceil)
# alpha has the same type as the grid, as we will directly use alpha
# when taking linear combinations of pixel values from the image.
alpha = torch.tensor((queries - floor), dtype=grid_type, device=grid_device)
min_alpha = torch.tensor(0.0, dtype=grid_type, device=grid_device)
max_alpha = torch.tensor(1.0, dtype=grid_type, device=grid_device)
alpha = torch.min(torch.max(min_alpha, alpha), max_alpha)
# Expand alpha to [b, n, 1] so we can use broadcasting
# (since the alpha values don't depend on the channel).
alpha = torch.unsqueeze(alpha, 2)
alphas.append(alpha)
flattened_grid = torch.reshape(grid, [batch_size * height * width, channels])
batch_offsets = torch.reshape(torch.arange(batch_size, device=grid_device) * height * width, [batch_size, 1])
# This wraps array_ops.gather. We reshape the image data such that the
# batch, y, and x coordinates are pulled into the first dimension.
# Then we gather. Finally, we reshape the output back. It's possible this
# code would be made simpler by using array_ops.gather_nd.
def gather(y_coords, x_coords, name):
linear_coordinates = batch_offsets + y_coords * width + x_coords
gathered_values = torch.gather(flattened_grid.t(), 1, linear_coordinates)
return torch.reshape(gathered_values, [batch_size, num_queries, channels])
# grab the pixel values in the 4 corners around each query point
top_left = gather(floors[0], floors[1], 'top_left')
top_right = gather(floors[0], ceils[1], 'top_right')
bottom_left = gather(ceils[0], floors[1], 'bottom_left')
bottom_right = gather(ceils[0], ceils[1], 'bottom_right')
interp_top = alphas[1] * (top_right - top_left) + top_left
interp_bottom = alphas[1] * (bottom_right - bottom_left) + bottom_left
interp = alphas[0] * (interp_bottom - interp_top) + interp_top
return interp
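A small worked sketch of the default 'ij' indexing: a 1x2x2x1 grid queried at its centre returns the mean of the four corner values.
import torch

grid = torch.tensor([[[[0.0], [1.0]],
                      [[2.0], [3.0]]]])   # shape (1, 2, 2, 1)
query = torch.tensor([[[0.5, 0.5]]])      # one query at (row, col) = (0.5, 0.5)
print(interpolate_bilinear(grid, query))  # tensor([[[1.5000]]])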
interpolate_spline(train_points, train_values, query_points, order, regularization_weight=0.0)
¶
Source code in adviser/tools/espnet_minimal/utils/spec_augment.py
def interpolate_spline(train_points, train_values, query_points, order, regularization_weight=0.0, ):
# First, fit the spline to the observed data.
w, v = solve_interpolation(train_points, train_values, order, regularization_weight)
# Then, evaluate the spline at the query locations.
query_values = apply_interpolation(query_points, train_points, w, v, order)
return query_values
phi(r, order)
¶
Coordinate-wise nonlinearity used to define the order of the interpolation.
See https://en.wikipedia.org/wiki/Polyharmonic_spline for the definition.
Args:
r: input op
order: interpolation order
Returns:
phi_k evaluated coordinate-wise on r, for k = r
Source code in adviser/tools/espnet_minimal/utils/spec_augment.py
def phi(r, order):
"""Coordinate-wise nonlinearity used to define the order of the interpolation.
See https://en.wikipedia.org/wiki/Polyharmonic_spline for the definition.
Args:
r: input op
order: interpolation order
Returns:
phi_k evaluated coordinate-wise on r, for k = r
"""
EPSILON = torch.tensor(1e-10, device=r.device)
# using EPSILON prevents log(0), sqrt(0), etc.
# sqrt(0) is well-defined, but its gradient is not
if order == 1:
r = torch.max(r, EPSILON)
r = torch.sqrt(r)
return r
elif order == 2:
return 0.5 * r * torch.log(torch.max(r, EPSILON))
elif order == 4:
return 0.5 * torch.square(r) * torch.log(torch.max(r, EPSILON))
elif order % 2 == 0:
r = torch.max(r, EPSILON)
return 0.5 * torch.pow(r, 0.5 * order) * torch.log(r)
else:
r = torch.max(r, EPSILON)
return torch.pow(r, 0.5 * order)
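For example, the first two kernel orders evaluated on a few radii (a quick sketch):
import torch

r = torch.tensor([0.0, 1.0, 4.0])
print(phi(r, order=1))  # sqrt(max(r, eps))          ~ [1e-05, 1.0, 2.0]
print(phi(r, order=2))  # 0.5 * r * log(max(r, eps)) ~ [0.0, 0.0, 2.7726]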
solve_interpolation(train_points, train_values, order, regularization_weight)
¶
Source code in adviser/tools/espnet_minimal/utils/spec_augment.py
def solve_interpolation(train_points, train_values, order, regularization_weight):
device = train_points.device
b, n, d = train_points.shape
k = train_values.shape[-1]
c = train_points
f = train_values.float()
matrix_a = phi(cross_squared_distance_matrix(c, c), order).unsqueeze(0) # [b, n, n]
# Append ones to the feature values for the bias term in the linear model.
ones = torch.ones(1, dtype=train_points.dtype, device=device).view([-1, 1, 1])
matrix_b = torch.cat((c, ones), 2).float() # [b, n, d + 1]
# [b, n + d + 1, n]
left_block = torch.cat((matrix_a, torch.transpose(matrix_b, 2, 1)), 1)
num_b_cols = matrix_b.shape[2] # d + 1
# In Tensorflow, zeros are used here. Pytorch solve fails with zeros for some reason we don't understand.
# So instead we use very tiny randn values (variance of one, zero mean) on one side of our multiplication.
lhs_zeros = torch.randn((b, num_b_cols, num_b_cols), device=device) / 1e10
right_block = torch.cat((matrix_b, lhs_zeros), 1) # [b, n + d + 1, d + 1]
lhs = torch.cat((left_block, right_block), 2) # [b, n + d + 1, n + d + 1]
rhs_zeros = torch.zeros((b, d + 1, k), dtype=train_points.dtype, device=device).float()
rhs = torch.cat((f, rhs_zeros), 1) # [b, n + d + 1, k]
# Then, solve the linear system and unpack the results.
X, LU = torch.gesv(rhs, lhs)
w = X[:, :n, :]
v = X[:, n:, :]
return w, v
sparse_image_warp(img_tensor, source_control_point_locations, dest_control_point_locations, interpolation_order=2, regularization_weight=0.0, num_boundaries_points=0)
¶
Source code in adviser/tools/espnet_minimal/utils/spec_augment.py
def sparse_image_warp(img_tensor,
source_control_point_locations,
dest_control_point_locations,
interpolation_order=2,
regularization_weight=0.0,
num_boundaries_points=0):
device = img_tensor.device
control_point_flows = dest_control_point_locations - source_control_point_locations
batch_size, image_height, image_width = img_tensor.shape
flattened_grid_locations = get_flat_grid_locations(image_height, image_width, device)
flattened_flows = interpolate_spline(
dest_control_point_locations,
control_point_flows,
flattened_grid_locations,
interpolation_order,
regularization_weight)
dense_flows = create_dense_flows(flattened_flows, batch_size, image_height, image_width)
warped_image = dense_image_warp(img_tensor, dense_flows)
return warped_image, dense_flows
specaug(spec, W=5, F=30, T=40, num_freq_masks=2, num_time_masks=2, replace_with_zero=False)
¶
SpecAugment
SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition
(https://arxiv.org/pdf/1904.08779.pdf)
This implementation modified from https://github.com/zcaceres/spec_augment
:param torch.Tensor spec: input tensor with the shape (T, dim)
:param int W: time warp parameter
:param int F: maximum width of each freq mask
:param int T: maximum width of each time mask
:param int num_freq_masks: number of frequency masks
:param int num_time_masks: number of time masks
:param bool replace_with_zero: if True, masked parts will be filled with 0, if False, filled with mean
Source code in adviser/tools/espnet_minimal/utils/spec_augment.py
def specaug(spec, W=5, F=30, T=40, num_freq_masks=2, num_time_masks=2, replace_with_zero=False):
"""SpecAugment
Reference: SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition
(https://arxiv.org/pdf/1904.08779.pdf)
This implementation modified from https://github.com/zcaceres/spec_augment
:param torch.Tensor spec: input tensor with the shape (T, dim)
:param int W: time warp parameter
:param int F: maximum width of each freq mask
:param int T: maximum width of each time mask
:param int num_freq_masks: number of frequency masks
:param int num_time_masks: number of time masks
:param bool replace_with_zero: if True, masked parts will be filled with 0, if False, filled with mean
"""
return time_mask(
freq_mask(time_warp(spec, W=W),
F=F, num_masks=num_freq_masks, replace_with_zero=replace_with_zero),
T=T, num_masks=num_time_masks, replace_with_zero=replace_with_zero)
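A usage sketch. Note that the time-warping step goes through solve_interpolation above, which calls torch.gesv; that solver only exists in older PyTorch releases, so on newer versions this call fails unless the solver is ported to torch.linalg.solve.
import torch

spec = torch.randn(200, 80)  # (T, dim) spectrogram-like input
augmented = specaug(spec, W=5, F=27, T=40,
                    num_freq_masks=2, num_time_masks=2,
                    replace_with_zero=True)
print(augmented.shape)       # torch.Size([200, 80])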
time_mask(spec, T=40, num_masks=1, replace_with_zero=False)
¶
Time masking
:param torch.Tensor spec: input tensor with shape (T, dim)
:param int T: maximum width of each mask
:param int num_masks: number of masks
:param bool replace_with_zero: if True, masked parts will be filled with 0, if False, filled with mean
Source code in adviser/tools/espnet_minimal/utils/spec_augment.py
def time_mask(spec, T=40, num_masks=1, replace_with_zero=False):
"""Time masking
:param torch.Tensor spec: input tensor with shape (T, dim)
:param int T: maximum width of each mask
:param int num_masks: number of masks
:param bool replace_with_zero: if True, masked parts will be filled with 0, if False, filled with mean
"""
cloned = spec.unsqueeze(0).clone()
len_spectro = cloned.shape[1]
for i in range(0, num_masks):
t = random.randrange(0, T)
t_zero = random.randrange(0, len_spectro - t)
# avoids randrange error if values are equal and range is empty
if (t_zero == t_zero + t):
return cloned.squeeze(0)
mask_end = random.randrange(t_zero, t_zero + t)
if (replace_with_zero):
cloned[0][t_zero:mask_end, :] = 0
else:
cloned[0][t_zero:mask_end, :] = cloned.mean()
return cloned.squeeze(0)
time_warp(spec, W=5)
¶
Time warping
:param torch.Tensor spec: input tensor with shape (T, dim)
:param int W: time warp parameter
Source code in adviser/tools/espnet_minimal/utils/spec_augment.py
def time_warp(spec, W=5):
"""Time warping
:param torch.Tensor spec: input tensor with shape (T, dim)
:param int W: time warp parameter
"""
spec = spec.unsqueeze(0)
spec_len = spec.shape[1]
num_rows = spec.shape[2]
device = spec.device
y = num_rows // 2
horizontal_line_at_ctr = spec[0, :, y]
assert len(horizontal_line_at_ctr) == spec_len
point_to_warp = horizontal_line_at_ctr[random.randrange(W, spec_len - W)]
assert isinstance(point_to_warp, torch.Tensor)
# Uniform distribution from (0,W) with chance to be up to W negative
dist_to_warp = random.randrange(-W, W)
src_pts, dest_pts = (torch.tensor([[[point_to_warp, y]]], device=device),
torch.tensor([[[point_to_warp + dist_to_warp, y]]], device=device))
warped_spectro, dense_flows = sparse_image_warp(spec, src_pts, dest_pts)
return warped_spectro.squeeze(3).squeeze(0)
training
special
¶
batchfy
¶
BATCH_COUNT_CHOICES
¶
BATCH_SORT_KEY_CHOICES
¶
batchfy_by_bin(sorted_data, batch_bins, num_batches=0, min_batch_size=1, shortest_first=False, ikey='input', okey='output')
¶
Make variably sized batch set, which maximizes the number of bins up to `batch_bins`.
:param Dict[str, Dict[str, Any]] sorted_data: dictionary loaded from data.json
:param int batch_bins: Maximum frames of a batch
:param int num_batches: # number of batches to use (for debug)
:param int min_batch_size: minimum batch size (for multi-gpu)
:param int test: Return only every `test` batches
:param bool shortest_first: Sort from batch with shortest samples to longest if true, otherwise reverse
:param str ikey: key to access input (for ASR ikey="input", for TTS ikey="output".)
:param str okey: key to access output (for ASR okey="output". for TTS okey="input".)
:return: List[Tuple[str, Dict[str, List[Dict[str, Any]]]] list of batches
Source code in adviser/tools/espnet_minimal/utils/training/batchfy.py
def batchfy_by_bin(sorted_data, batch_bins, num_batches=0, min_batch_size=1, shortest_first=False,
ikey="input", okey="output"):
"""Make variably sized batch set, which maximizes the number of bins up to `batch_bins`.
:param Dict[str, Dict[str, Any]] sorted_data: dictionary loaded from data.json
:param int batch_bins: Maximum frames of a batch
:param int num_batches: # number of batches to use (for debug)
:param int min_batch_size: minimum batch size (for multi-gpu)
:param int test: Return only every `test` batches
:param bool shortest_first: Sort from batch with shortest samples to longest if true, otherwise reverse
:param str ikey: key to access input (for ASR ikey="input", for TTS ikey="output".)
:param str okey: key to access output (for ASR okey="output". for TTS okey="input".)
:return: List[Tuple[str, Dict[str, List[Dict[str, Any]]]] list of batches
"""
if batch_bins <= 0:
raise ValueError(f"invalid batch_bins={batch_bins}")
length = len(sorted_data)
idim = int(sorted_data[0][1][ikey][0]['shape'][1])
odim = int(sorted_data[0][1][okey][0]['shape'][1])
logging.info('# utts: ' + str(len(sorted_data)))
minibatches = []
start = 0
n = 0
while True:
# Dynamic batch size depending on size of samples
b = 0
next_size = 0
max_olen = 0
while next_size < batch_bins and (start + b) < length:
ilen = int(sorted_data[start + b][1][ikey][0]['shape'][0]) * idim
olen = int(sorted_data[start + b][1][okey][0]['shape'][0]) * odim
if olen > max_olen:
max_olen = olen
next_size = (max_olen + ilen) * (b + 1)
if next_size <= batch_bins:
b += 1
elif next_size == 0:
raise ValueError(
f"Can't fit one sample in batch_bins ({batch_bins}): Please increase the value")
end = min(length, start + max(min_batch_size, b))
batch = sorted_data[start:end]
if shortest_first:
batch.reverse()
minibatches.append(batch)
# Check for min_batch_size and fixes the batches if needed
i = -1
while len(minibatches[i]) < min_batch_size:
missing = min_batch_size - len(minibatches[i])
if -i == len(minibatches):
minibatches[i + 1].extend(minibatches[i])
minibatches = minibatches[1:]
break
else:
minibatches[i].extend(minibatches[i - 1][:missing])
minibatches[i - 1] = minibatches[i - 1][missing:]
i -= 1
if end == length:
break
start = end
n += 1
if num_batches > 0:
minibatches = minibatches[:num_batches]
lengths = [len(x) for x in minibatches]
logging.info(str(len(minibatches)) + " batches containing from " +
str(min(lengths)) + " to " + str(max(lengths)) + " samples " +
"(avg " + str(int(np.mean(lengths))) + " samples).")
return minibatches
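A small sketch with two utterances already sorted by input length; the shapes are invented and only the 'shape' fields are read here.
sorted_data = [
    ("utt1", {"input":  [{"name": "input1",  "shape": [400, 83]}],
              "output": [{"name": "target1", "shape": [20, 52]}]}),
    ("utt2", {"input":  [{"name": "input1",  "shape": [300, 83]}],
              "output": [{"name": "target1", "shape": [15, 52]}]}),
]
batches = batchfy_by_bin(sorted_data, batch_bins=80000)
print([len(b) for b in batches])  # [2] -> both utterances fit into one bin-limited batch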
batchfy_by_frame(sorted_data, max_frames_in, max_frames_out, max_frames_inout, num_batches=0, min_batch_size=1, shortest_first=False, ikey='input', okey='output')
¶
Make variably sized batch set, which maximizes the number of frames up to max_batch_frame.
:param Dict[str, Dict[str, Any]] sorted_data: dictionary loaded from data.json
:param int max_frames_in: Maximum input frames of a batch
:param int max_frames_out: Maximum output frames of a batch
:param int max_frames_inout: Maximum input+output frames of a batch
:param int num_batches: # number of batches to use (for debug)
:param int min_batch_size: minimum batch size (for multi-gpu)
:param int test: Return only every `test` batches
:param bool shortest_first: Sort from batch with shortest samples to longest if true, otherwise reverse
:param str ikey: key to access input (for ASR ikey="input", for TTS ikey="output".)
:param str okey: key to access output (for ASR okey="output". for TTS okey="input".)
:return: List[Tuple[str, Dict[str, List[Dict[str, Any]]]] list of batches
Source code in adviser/tools/espnet_minimal/utils/training/batchfy.py
def batchfy_by_frame(sorted_data, max_frames_in, max_frames_out, max_frames_inout,
num_batches=0, min_batch_size=1, shortest_first=False,
ikey="input", okey="output"):
"""Make variably sized batch set, which maximizes the number of frames to max_batch_frame.
:param Dict[str, Dict[str, Any]] sorteddata: dictionary loaded from data.json
:param int max_frames_in: Maximum input frames of a batch
:param int max_frames_out: Maximum output frames of a batch
:param int max_frames_inout: Maximum input+output frames of a batch
:param int num_batches: # number of batches to use (for debug)
:param int min_batch_size: minimum batch size (for multi-gpu)
:param int test: Return only every `test` batches
:param bool shortest_first: Sort from batch with shortest samples to longest if true, otherwise reverse
:param str ikey: key to access input (for ASR ikey="input", for TTS ikey="output".)
:param str okey: key to access output (for ASR okey="output". for TTS okey="input".)
:return: List[Tuple[str, Dict[str, List[Dict[str, Any]]]] list of batches
"""
if max_frames_in <= 0 and max_frames_out <= 0 and max_frames_inout <= 0:
raise ValueError(
f"At least, one of `--batch-frames-in`, `--batch-frames-out` or `--batch-frames-inout` should be > 0")
length = len(sorted_data)
minibatches = []
start = 0
end = 0
while end != length:
# Dynamic batch size depending on size of samples
b = 0
max_olen = 0
max_ilen = 0
while (start + b) < length:
ilen = int(sorted_data[start + b][1][ikey][0]['shape'][0])
if ilen > max_frames_in and max_frames_in != 0:
raise ValueError(
f"Can't fit one sample in --batch-frames-in ({max_frames_in}): Please increase the value")
olen = int(sorted_data[start + b][1][okey][0]['shape'][0])
if olen > max_frames_out and max_frames_out != 0:
raise ValueError(
f"Can't fit one sample in --batch-frames-out ({max_frames_out}): Please increase the value")
if ilen + olen > max_frames_inout and max_frames_inout != 0:
raise ValueError(
f"Can't fit one sample in --batch-frames-out ({max_frames_inout}): Please increase the value")
max_olen = max(max_olen, olen)
max_ilen = max(max_ilen, ilen)
in_ok = max_ilen * (b + 1) <= max_frames_in or max_frames_in == 0
out_ok = max_olen * (b + 1) <= max_frames_out or max_frames_out == 0
inout_ok = (max_ilen + max_olen) * (b + 1) <= max_frames_inout or max_frames_inout == 0
if in_ok and out_ok and inout_ok:
# add more seq in the minibatch
b += 1
else:
# no more seq in the minibatch
break
end = min(length, start + b)
batch = sorted_data[start:end]
if shortest_first:
batch.reverse()
minibatches.append(batch)
# Check for min_batch_size and fixes the batches if needed
i = -1
while len(minibatches[i]) < min_batch_size:
missing = min_batch_size - len(minibatches[i])
if -i == len(minibatches):
minibatches[i + 1].extend(minibatches[i])
minibatches = minibatches[1:]
break
else:
minibatches[i].extend(minibatches[i - 1][:missing])
minibatches[i - 1] = minibatches[i - 1][missing:]
i -= 1
start = end
if num_batches > 0:
minibatches = minibatches[:num_batches]
lengths = [len(x) for x in minibatches]
logging.info(str(len(minibatches)) + " batches containing from " +
str(min(lengths)) + " to " + str(max(lengths)) + " samples" +
"(avg " + str(int(np.mean(lengths))) + " samples).")
return minibatches
batchfy_by_seq(sorted_data, batch_size, max_length_in, max_length_out, min_batch_size=1, shortest_first=False, ikey='input', iaxis=0, okey='output', oaxis=0)
¶
Make batch set from json dictionary
:param Dict[str, Dict[str, Any]] sorted_data: dictionary loaded from data.json
:param int batch_size: batch size
:param int max_length_in: maximum length of input to decide adaptive batch size
:param int max_length_out: maximum length of output to decide adaptive batch size
:param int min_batch_size: minimum batch size (for multi-gpu)
:param bool shortest_first: Sort from batch with shortest samples to longest if true, otherwise reverse
:param str ikey: key to access input (for ASR ikey="input", for TTS, MT ikey="output".)
:param int iaxis: dimension to access input (for ASR, TTS iaxis=0, for MT iaxis="1".)
:param str okey: key to access output (for ASR, MT okey="output". for TTS okey="input".)
:param int oaxis: dimension to access output (for ASR, TTS, MT oaxis=0, reserved for future research, -1 means all axis.)
:return: List[List[Tuple[str, dict]]] list of batches
Source code in adviser/tools/espnet_minimal/utils/training/batchfy.py
def batchfy_by_seq(
sorted_data, batch_size, max_length_in, max_length_out,
min_batch_size=1, shortest_first=False,
ikey="input", iaxis=0, okey="output", oaxis=0):
"""Make batch set from json dictionary
:param Dict[str, Dict[str, Any]] sorted_data: dictionary loaded from data.json
:param int batch_size: batch size
:param int max_length_in: maximum length of input to decide adaptive batch size
:param int max_length_out: maximum length of output to decide adaptive batch size
:param int min_batch_size: mininum batch size (for multi-gpu)
:param bool shortest_first: Sort from batch with shortest samples to longest if true, otherwise reverse
:param str ikey: key to access input (for ASR ikey="input", for TTS, MT ikey="output".)
:param int iaxis: dimension to access input (for ASR, TTS iaxis=0, for MT iaxis="1".)
:param str okey: key to access output (for ASR, MT okey="output". for TTS okey="input".)
:param int oaxis: dimension to access output (for ASR, TTS, MT oaxis=0, reserved for future research,
-1 means all axis.)
:return: List[List[Tuple[str, dict]]] list of batches
"""
if batch_size <= 0:
raise ValueError(f"Invalid batch_size={batch_size}")
# check #utts is more than min_batch_size
if len(sorted_data) < min_batch_size:
raise ValueError(f"#utts({len(sorted_data)}) is less than min_batch_size({min_batch_size}).")
# make list of minibatches
minibatches = []
start = 0
while True:
_, info = sorted_data[start]
ilen = int(info[ikey][iaxis]['shape'][0])
olen = int(info[okey][oaxis]['shape'][0]) if oaxis >= 0 else max(map(lambda x: int(x['shape'][0]), info[okey]))
factor = max(int(ilen / max_length_in), int(olen / max_length_out))
# change batchsize depending on the input and output length
# if ilen = 1000 and max_length_in = 800
# then b = batchsize / 2
# and max(min_batches, .) avoids batchsize = 0
bs = max(min_batch_size, int(batch_size / (1 + factor)))
end = min(len(sorted_data), start + bs)
minibatch = sorted_data[start:end]
if shortest_first:
minibatch.reverse()
# check each batch is more than minimum batchsize
if len(minibatch) < min_batch_size:
mod = min_batch_size - len(minibatch) % min_batch_size
additional_minibatch = [sorted_data[i]
for i in np.random.randint(0, start, mod)]
if shortest_first:
additional_minibatch.reverse()
minibatch.extend(additional_minibatch)
minibatches.append(minibatch)
if end == len(sorted_data):
break
start = end
# batch: List[List[Tuple[str, dict]]]
return minibatches
batchfy_shuffle(data, batch_size, min_batch_size, num_batches, shortest_first)
¶
Source code in adviser/tools/espnet_minimal/utils/training/batchfy.py
def batchfy_shuffle(data, batch_size, min_batch_size, num_batches, shortest_first):
import random
logging.info('use shuffled batch.')
sorted_data = random.sample(data.items(), len(data.items()))
logging.info('# utts: ' + str(len(sorted_data)))
# make list of minibatches
minibatches = []
start = 0
while True:
end = min(len(sorted_data), start + batch_size)
# check each batch is more than minimum batchsize
minibatch = sorted_data[start:end]
if shortest_first:
minibatch.reverse()
if len(minibatch) < min_batch_size:
mod = min_batch_size - len(minibatch) % min_batch_size
additional_minibatch = [sorted_data[i] for i in np.random.randint(0, start, mod)]
if shortest_first:
additional_minibatch.reverse()
minibatch.extend(additional_minibatch)
minibatches.append(minibatch)
if end == len(sorted_data):
break
start = end
# for debugging
if num_batches > 0:
minibatches = minibatches[:num_batches]
logging.info('# minibatches: ' + str(len(minibatches)))
return minibatches
make_batchset(data, batch_size=0, max_length_in=inf, max_length_out=inf, num_batches=0, min_batch_size=1, shortest_first=False, batch_sort_key='input', swap_io=False, mt=False, count='auto', batch_bins=0, batch_frames_in=0, batch_frames_out=0, batch_frames_inout=0, iaxis=0, oaxis=0)
¶
Make batch set from json dictionary
if utts have "category" value,
>>> data = {'utt1': {'category': 'A', 'input': ...},
... 'utt2': {'category': 'B', 'input': ...},
... 'utt3': {'category': 'B', 'input': ...},
... 'utt4': {'category': 'A', 'input': ...}}
>>> make_batchset(data, batchsize=2, ...)
[[('utt1', ...), ('utt4', ...)], [('utt2', ...), ('utt3': ...)]]
Note that if any utterance doesn't have "category", this performs the same as batchfy_by_{count}.
:param Dict[str, Dict[str, Any]] data: dictionary loaded from data.json
:param int batch_size: maximum number of sequences in a minibatch.
:param int batch_bins: maximum number of bins (frames x dim) in a minibatch.
:param int batch_frames_in: maximum number of input frames in a minibatch.
:param int batch_frames_out: maximum number of output frames in a minibatch.
:param int batch_frames_inout: maximum number of input+output frames in a minibatch.
:param str count: strategy to count maximum size of batch. For choices, see services.hci.speech.espnet_minimal.asr.batchfy.BATCH_COUNT_CHOICES
:param int max_length_in: maximum length of input to decide adaptive batch size
:param int max_length_out: maximum length of output to decide adaptive batch size
:param int num_batches: # number of batches to use (for debug)
:param int min_batch_size: minimum batch size (for multi-gpu)
:param bool shortest_first: Sort from batch with shortest samples to longest if true, otherwise reverse
:param str batch_sort_key: how to sort data before creating minibatches ["input", "output", "shuffle"]
:param bool swap_io: if True, use "input" as output and "output" as input in the `data` dict
:param bool mt: if True, use 0-axis of "output" as output and 1-axis of "output" as input in the `data` dict
:param int iaxis: dimension to access input (for ASR, TTS iaxis=0, for MT iaxis="1".)
:param int oaxis: dimension to access output (for ASR, TTS, MT oaxis=0, reserved for future research, -1 means all axis.)
:return: List[List[Tuple[str, dict]]] list of batches
Source code in adviser/tools/espnet_minimal/utils/training/batchfy.py
def make_batchset(data, batch_size=0, max_length_in=float("inf"), max_length_out=float("inf"),
num_batches=0, min_batch_size=1, shortest_first=False, batch_sort_key="input",
swap_io=False, mt=False, count="auto",
batch_bins=0, batch_frames_in=0, batch_frames_out=0, batch_frames_inout=0,
iaxis=0, oaxis=0):
"""Make batch set from json dictionary
if utts have "category" value,
>>> data = {'utt1': {'category': 'A', 'input': ...},
... 'utt2': {'category': 'B', 'input': ...},
... 'utt3': {'category': 'B', 'input': ...},
... 'utt4': {'category': 'A', 'input': ...}}
>>> make_batchset(data, batchsize=2, ...)
[[('utt1', ...), ('utt4', ...)], [('utt2', ...), ('utt3': ...)]]
Note that if any utts doesn't have "category",
perform as same as batchfy_by_{count}
:param Dict[str, Dict[str, Any]] data: dictionary loaded from data.json
:param int batch_size: maximum number of sequences in a minibatch.
:param int batch_bins: maximum number of bins (frames x dim) in a minibatch.
:param int batch_frames_in: maximum number of input frames in a minibatch.
:param int batch_frames_out: maximum number of output frames in a minibatch.
:param int batch_frames_out: maximum number of input+output frames in a minibatch.
:param str count: strategy to count maximum size of batch.
For choices, see services.hci.speech.espnet_minimal.asr.batchfy.BATCH_COUNT_CHOICES
:param int max_length_in: maximum length of input to decide adaptive batch size
:param int max_length_out: maximum length of output to decide adaptive batch size
:param int num_batches: # number of batches to use (for debug)
:param int min_batch_size: minimum batch size (for multi-gpu)
:param bool shortest_first: Sort from batch with shortest samples to longest if true, otherwise reverse
:return: List[List[Tuple[str, dict]]] list of batches
:param str batch_sort_key: how to sort data before creating minibatches ["input", "output", "shuffle"]
:param bool swap_io: if True, use "input" as output and "output" as input in `data` dict
:param bool mt: if True, use 0-axis of "output" as output and 1-axis of "output" as input in `data` dict
:param int iaxis: dimension to access input (for ASR, TTS iaxis=0, for MT iaxis="1".)
:param int oaxis: dimension to access output (for ASR, TTS, MT oaxis=0, reserved for future research,
-1 means all axis.)
"""
# check args
if count not in BATCH_COUNT_CHOICES:
raise ValueError(f"arg 'count' ({count}) should be one of {BATCH_COUNT_CHOICES}")
if batch_sort_key not in BATCH_SORT_KEY_CHOICES:
raise ValueError(f"arg 'batch_sort_key' ({batch_sort_key}) should be one of {BATCH_SORT_KEY_CHOICES}")
# TODO(karita): remove this by creating converter from ASR to TTS json format
batch_sort_axis = 0
if swap_io:
# for TTS
ikey = "output"
okey = "input"
if batch_sort_key == "input":
batch_sort_key = "output"
elif batch_sort_key == "output":
batch_sort_key = "input"
elif mt:
# for MT
ikey = "output"
okey = "output"
batch_sort_key = "output"
batch_sort_axis = 1
assert iaxis == 1
assert oaxis == 0
# NOTE: input is json['output'][1] and output is json['output'][0]
else:
ikey = "input"
okey = "output"
if count == "auto":
if batch_size != 0:
count = "seq"
elif batch_bins != 0:
count = "bin"
elif batch_frames_in != 0 or batch_frames_out != 0 or batch_frames_inout != 0:
count = "frame"
else:
raise ValueError(f"cannot detect `count` manually set one of {BATCH_COUNT_CHOICES}")
logging.info(f"count is auto detected as {count}")
if count != "seq" and batch_sort_key == "shuffle":
raise ValueError(f"batch_sort_key=shuffle is only available if batch_count=seq")
category2data = {} # Dict[str, dict]
for k, v in data.items():
category2data.setdefault(v.get('category'), {})[k] = v
batches_list = [] # List[List[List[Tuple[str, dict]]]]
for d in category2data.values():
if batch_sort_key == 'shuffle':
batches = batchfy_shuffle(d, batch_size, min_batch_size, num_batches, shortest_first)
batches_list.append(batches)
continue
# sort it by input lengths (long to short)
sorted_data = sorted(d.items(), key=lambda data: int(
data[1][batch_sort_key][batch_sort_axis]['shape'][0]), reverse=not shortest_first)
logging.info('# utts: ' + str(len(sorted_data)))
if count == "seq":
batches = batchfy_by_seq(
sorted_data,
batch_size=batch_size,
max_length_in=max_length_in,
max_length_out=max_length_out,
min_batch_size=min_batch_size,
shortest_first=shortest_first,
ikey=ikey, iaxis=iaxis, okey=okey, oaxis=oaxis)
if count == "bin":
batches = batchfy_by_bin(
sorted_data,
batch_bins=batch_bins,
min_batch_size=min_batch_size,
shortest_first=shortest_first,
ikey=ikey, okey=okey)
if count == "frame":
batches = batchfy_by_frame(
sorted_data,
max_frames_in=batch_frames_in,
max_frames_out=batch_frames_out,
max_frames_inout=batch_frames_inout,
min_batch_size=min_batch_size,
shortest_first=shortest_first,
ikey=ikey, okey=okey)
batches_list.append(batches)
if len(batches_list) == 1:
batches = batches_list[0]
else:
# Concat list. This way is faster than "sum(batch_list, [])"
batches = list(itertools.chain(*batches_list))
# for debugging
if num_batches > 0:
batches = batches[:num_batches]
logging.info('# minibatches: ' + str(len(batches)))
# batch: List[List[Tuple[str, dict]]]
return batches
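As a hedged usage sketch (not taken from the ADVISER sources; all shape values are invented for illustration): a toy data.json-style dictionary is grouped into sequence-counted minibatches of at most two utterances.
# Assumes make_batchset has been imported from
# adviser/tools/espnet_minimal/utils/training/batchfy.py; shapes are made up.
data = {
    f"utt{i}": {
        "input":  [{"shape": [100 + 10 * i, 83]}],   # (frames, feature dim)
        "output": [{"shape": [20 + i, 52]}],         # (tokens, vocabulary size)
    }
    for i in range(5)
}
batches = make_batchset(data, batch_size=2)   # batch_size != 0, so count is auto-detected as "seq"
for batch in batches:
    print([name for name, _ in batch])        # utterance ids, longest inputs first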
evaluator
¶
BaseEvaluator (Evaluator)
¶
Base Evaluator in ESPnet
Source code in adviser/tools/espnet_minimal/utils/training/evaluator.py
class BaseEvaluator(Evaluator):
"""Base Evaluator in ESPnet"""
def __call__(self, trainer=None):
ret = super().__call__(trainer)
try:
if trainer is not None:
# force tensorboard to report evaluation log
tb_logger = trainer.get_extension(TensorboardLogger.default_name)
tb_logger(trainer)
except ValueError:
pass
return ret
__call__(self, trainer=None)
special
¶
Source code in adviser/tools/espnet_minimal/utils/training/evaluator.py
def __call__(self, trainer=None):
ret = super().__call__(trainer)
try:
if trainer is not None:
# force tensorboard to report evaluation log
tb_logger = trainer.get_extension(TensorboardLogger.default_name)
tb_logger(trainer)
except ValueError:
pass
return ret
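A hedged wiring sketch (none of these names come from this module: valid_iter, model, converter, device and trainer are placeholders): the evaluator is attached as a regular chainer trainer extension, and each evaluation run additionally triggers the extension registered under TensorboardLogger.default_name, if one exists.
# Sketch only: `valid_iter`, `model`, `converter`, `device` and `trainer` are
# assumed to be created by the surrounding training script.
evaluator = BaseEvaluator(valid_iter, model, converter=converter, device=device)
trainer.extend(evaluator, trigger=(1, "epoch"))
# Each evaluation pass also forces the registered TensorboardLogger (if any)
# to write its current observations, so evaluation metrics appear immediately.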
iterators
¶
ShufflingEnabler (Extension)
¶
An extension enabling shuffling on an Iterator
Source code in adviser/tools/espnet_minimal/utils/training/iterators.py
class ShufflingEnabler(Extension):
"""An extension enabling shuffling on an Iterator"""
def __init__(self, iterators):
"""Inits the ShufflingEnabler
:param list[Iterator] iterators: The iterators to enable shuffling on
"""
self.set = False
self.iterators = iterators
def __call__(self, trainer):
"""Calls the enabler on the given iterator
:param trainer: The iterator
"""
if not self.set:
for iterator in self.iterators:
iterator.start_shuffle()
self.set = True
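A hedged sketch of the usual pattern (trainer and train_iter are placeholders created elsewhere in the training script): keep the length-sorted batch order for the first epoch and let the enabler switch the iterator to shuffled order once it fires.
# Sketch: flip the iterator to shuffled order after the first epoch
# ("sorted first epoch, shuffled afterwards"); `trainer` and `train_iter`
# are placeholders for objects built elsewhere.
trainer.extend(ShufflingEnabler([train_iter]), trigger=(1, "epoch"))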
ToggleableShufflingMultiprocessIterator (MultiprocessIterator)
¶
A MultiprocessIterator that can have its shuffling property activated during training
Source code in adviser/tools/espnet_minimal/utils/training/iterators.py
class ToggleableShufflingMultiprocessIterator(MultiprocessIterator):
"""A MultiprocessIterator that can have its shuffling property activated during training"""
def __init__(self, dataset, batch_size, repeat=True, shuffle=True, n_processes=None, n_prefetch=1, shared_mem=None,
maxtasksperchild=20):
"""Init the iterator
:param torch.nn.Tensor dataset: The dataset to take batches from
:param int batch_size: The batch size
:param bool repeat: Whether to repeat batches or not (enables multiple epochs)
:param bool shuffle: Whether to shuffle the order of the batches
:param int n_processes: How many processes to use
:param int n_prefetch: The number of prefetch to use
:param int shared_mem: How many memory to share between processes
:param int maxtasksperchild: Maximum number of tasks per child
"""
super(ToggleableShufflingMultiprocessIterator, self).__init__(dataset=dataset, batch_size=batch_size,
repeat=repeat, shuffle=shuffle,
n_processes=n_processes,
n_prefetch=n_prefetch, shared_mem=shared_mem,
maxtasksperchild=maxtasksperchild)
def start_shuffle(self):
"""Starts shuffling (or reshuffles) the batches"""
self.shuffle = True
if int(chainer._version.__version__[0]) <= 4:
self._order = np.random.permutation(len(self.dataset))
else:
self.order_sampler = ShuffleOrderSampler()
self._order = self.order_sampler(np.arange(len(self.dataset)), 0)
self._set_prefetch_state()
__init__(self, dataset, batch_size, repeat=True, shuffle=True, n_processes=None, n_prefetch=1, shared_mem=None, maxtasksperchild=20)
special
¶
Init the iterator
:param torch.nn.Tensor dataset: The dataset to take batches from
:param int batch_size: The batch size
:param bool repeat: Whether to repeat batches or not (enables multiple epochs)
:param bool shuffle: Whether to shuffle the order of the batches
:param int n_processes: How many processes to use
:param int n_prefetch: The number of batches to prefetch
:param int shared_mem: How much memory to share between processes
:param int maxtasksperchild: Maximum number of tasks per child
Source code in adviser/tools/espnet_minimal/utils/training/iterators.py
def __init__(self, dataset, batch_size, repeat=True, shuffle=True, n_processes=None, n_prefetch=1, shared_mem=None,
maxtasksperchild=20):
"""Init the iterator
:param torch.nn.Tensor dataset: The dataset to take batches from
:param int batch_size: The batch size
:param bool repeat: Whether to repeat batches or not (enables multiple epochs)
:param bool shuffle: Whether to shuffle the order of the batches
:param int n_processes: How many processes to use
:param int n_prefetch: The number of prefetch to use
:param int shared_mem: How many memory to share between processes
:param int maxtasksperchild: Maximum number of tasks per child
"""
super(ToggleableShufflingMultiprocessIterator, self).__init__(dataset=dataset, batch_size=batch_size,
repeat=repeat, shuffle=shuffle,
n_processes=n_processes,
n_prefetch=n_prefetch, shared_mem=shared_mem,
maxtasksperchild=maxtasksperchild)
start_shuffle(self)
¶
Starts shuffling (or reshuffles) the batches
Source code in adviser/tools/espnet_minimal/utils/training/iterators.py
def start_shuffle(self):
"""Starts shuffling (or reshuffles) the batches"""
self.shuffle = True
if int(chainer._version.__version__[0]) <= 4:
self._order = np.random.permutation(len(self.dataset))
else:
self.order_sampler = ShuffleOrderSampler()
self._order = self.order_sampler(np.arange(len(self.dataset)), 0)
self._set_prefetch_state()
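An illustrative construction only; train_dataset stands in for a chainer-compatible dataset prepared elsewhere, and the numeric values are arbitrary. The iterator starts unshuffled (e.g. for a sorted first epoch) and is switched later, typically by ShufflingEnabler.
# Illustrative values; `train_dataset` is a placeholder for a chainer dataset.
train_iter = ToggleableShufflingMultiprocessIterator(
    train_dataset, batch_size=1, repeat=True, shuffle=False,
    n_processes=4, n_prefetch=8, maxtasksperchild=20)
# Later, e.g. when ShufflingEnabler fires, switch to shuffled order:
train_iter.start_shuffle()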
ToggleableShufflingSerialIterator (SerialIterator)
¶
A SerialIterator that can have its shuffling property activated during training
Source code in adviser/tools/espnet_minimal/utils/training/iterators.py
class ToggleableShufflingSerialIterator(SerialIterator):
"""A SerialIterator that can have its shuffling property activated during training"""
def __init__(self, dataset, batch_size, repeat=True, shuffle=True):
"""Init the Iterator
:param torch.nn.Tensor dataset: The dataset to take batches from
:param int batch_size: The batch size
:param bool repeat: Whether to repeat data (allow multiple epochs)
:param bool shuffle: Whether to shuffle the batches
"""
super(ToggleableShufflingSerialIterator, self).__init__(dataset, batch_size, repeat, shuffle)
def start_shuffle(self):
"""Starts shuffling (or reshuffles) the batches"""
self._shuffle = True
if int(chainer._version.__version__[0]) <= 4:
self._order = np.random.permutation(len(self.dataset))
else:
self.order_sampler = ShuffleOrderSampler()
self._order = self.order_sampler(np.arange(len(self.dataset)), 0)
__init__(self, dataset, batch_size, repeat=True, shuffle=True)
special
¶
Init the Iterator
:param torch.nn.Tensor dataset: The dataset to take batches from
:param int batch_size: The batch size
:param bool repeat: Whether to repeat data (allow multiple epochs)
:param bool shuffle: Whether to shuffle the batches
Source code in adviser/tools/espnet_minimal/utils/training/iterators.py
def __init__(self, dataset, batch_size, repeat=True, shuffle=True):
"""Init the Iterator
:param torch.nn.Tensor dataset: The dataset to take batches from
:param int batch_size: The batch size
:param bool repeat: Whether to repeat data (allow multiple epochs)
:param bool shuffle: Whether to shuffle the batches
"""
super(ToggleableShufflingSerialIterator, self).__init__(dataset, batch_size, repeat, shuffle)
start_shuffle(self)
¶
Starts shuffling (or reshuffles) the batches
Source code in adviser/tools/espnet_minimal/utils/training/iterators.py
def start_shuffle(self):
"""Starts shuffling (or reshuffles) the batches"""
self._shuffle = True
if int(chainer._version.__version__[0]) <= 4:
self._order = np.random.permutation(len(self.dataset))
else:
self.order_sampler = ShuffleOrderSampler()
self._order = self.order_sampler(np.arange(len(self.dataset)), 0)
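A small self-contained sketch, assuming chainer and numpy are installed; the import path is a guess and may need adjusting to how the adviser package is laid out on your PYTHONPATH.
import numpy as np
# Import path is an assumption; adjust it to your installation.
from tools.espnet_minimal.utils.training.iterators import ToggleableShufflingSerialIterator

toy_batches = [np.array([i]) for i in range(6)]            # six dummy "batches"
it = ToggleableShufflingSerialIterator(toy_batches, batch_size=1,
                                       repeat=True, shuffle=False)
first_pass = [it.next() for _ in range(6)]                 # served in original order
it.start_shuffle()                                         # later passes come out shuffled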
tensorboard_logger
¶
TensorboardLogger (Extension)
¶
A tensorboard logger extension
Source code in adviser/tools/espnet_minimal/utils/training/tensorboard_logger.py
class TensorboardLogger(Extension):
"""A tensorboard logger extension"""
default_name = "espnet_tensorboard_logger"
def __init__(self, logger, att_reporter=None, entries=None, epoch=0):
"""Init the extension
:param SummaryWriter logger: The logger to use
:param PlotAttentionReporter att_reporter: The (optional) PlotAttentionReporter
:param entries: The entries to watch
:param int epoch: The starting epoch
"""
self._entries = entries
self._att_reporter = att_reporter
self._logger = logger
self._epoch = epoch
def __call__(self, trainer):
"""Updates the events file with the new values
:param trainer: The trainer
"""
observation = trainer.observation
for k, v in observation.items():
if (self._entries is not None) and (k not in self._entries):
continue
if k is not None and v is not None:
if 'cupy' in str(type(v)):
v = v.get()
if 'cupy' in str(type(k)):
k = k.get()
self._logger.add_scalar(k, v, trainer.updater.iteration)
if self._att_reporter is not None and trainer.updater.get_iterator('main').epoch > self._epoch:
self._epoch = trainer.updater.get_iterator('main').epoch
self._att_reporter.log_attentions(self._logger, trainer.updater.iteration)
default_name
¶
__call__(self, trainer)
special
¶
Updates the events file with the new values
:param trainer: The trainer
Source code in adviser/tools/espnet_minimal/utils/training/tensorboard_logger.py
def __call__(self, trainer):
"""Updates the events file with the new values
:param trainer: The trainer
"""
observation = trainer.observation
for k, v in observation.items():
if (self._entries is not None) and (k not in self._entries):
continue
if k is not None and v is not None:
if 'cupy' in str(type(v)):
v = v.get()
if 'cupy' in str(type(k)):
k = k.get()
self._logger.add_scalar(k, v, trainer.updater.iteration)
if self._att_reporter is not None and trainer.updater.get_iterator('main').epoch > self._epoch:
self._epoch = trainer.updater.get_iterator('main').epoch
self._att_reporter.log_attentions(self._logger, trainer.updater.iteration)
__init__(self, logger, att_reporter=None, entries=None, epoch=0)
special
¶
Init the extension
:param SummaryWriter logger: The logger to use
:param PlotAttentionReporter att_reporter: The (optional) PlotAttentionReporter
:param entries: The entries to watch
:param int epoch: The starting epoch
Source code in adviser/tools/espnet_minimal/utils/training/tensorboard_logger.py
def __init__(self, logger, att_reporter=None, entries=None, epoch=0):
"""Init the extension
:param SummaryWriter logger: The logger to use
:param PlotAttentionReporter att_reporter: The (optional) PlotAttentionReporter
:param entries: The entries to watch
:param int epoch: The starting epoch
"""
self._entries = entries
self._att_reporter = att_reporter
self._logger = logger
self._epoch = epoch
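A hedged wiring example: the SummaryWriter is assumed to come from tensorboardX (any compatible writer should work), the log directory and entry names are invented, and trainer is a placeholder for a configured chainer trainer.
from tensorboardX import SummaryWriter   # assumption: tensorboardX is available

writer = SummaryWriter("exp/tensorboard")                  # hypothetical log directory
tb_logger = TensorboardLogger(writer,
                              entries=["main/loss", "validation/main/loss"])
trainer.extend(tb_logger, trigger=(100, "iteration"))      # `trainer` is a placeholder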
train_utils
¶
check_early_stop(trainer, epochs)
¶
Checks if the training was stopped by an early stopping trigger and warns the user if it's the case
:param trainer: The trainer used for training
:param epochs: The maximum number of epochs
Source code in adviser/tools/espnet_minimal/utils/training/train_utils.py
def check_early_stop(trainer, epochs):
"""Checks if the training was stopped by an early stopping trigger and warns the user if it's the case
:param trainer: The trainer used for training
:param epochs: The maximum number of epochs
"""
end_epoch = trainer.updater.get_iterator('main').epoch
if end_epoch < (epochs - 1):
logging.warning("Hit early stop at epoch " + str(
end_epoch) + "\nYou can change the patience or set it to 0 to run all epochs")
set_early_stop(trainer, args, is_lm=False)
¶
Sets the early stop trigger given the program arguments
:param trainer: The trainer used for training
:param args: The program arguments
:param is_lm: If the trainer is for a LM (epoch instead of epochs)
Source code in adviser/tools/espnet_minimal/utils/training/train_utils.py
def set_early_stop(trainer, args, is_lm=False):
"""Sets the early stop trigger given the program arguments
:param trainer: The trainer used for training
:param args: The program arguments
:param is_lm: If the trainer is for a LM (epoch instead of epochs)
"""
patience = args.patience
criterion = args.early_stop_criterion
epochs = args.epoch if is_lm else args.epochs
mode = 'max' if 'acc' in criterion else 'min'
if patience > 0:
trainer.stop_trigger = chainer.training.triggers.EarlyStoppingTrigger(monitor=criterion,
mode=mode,
patients=patience,
max_trigger=(epochs, 'epoch'))
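A sketch of how the two helpers fit together, with hypothetical argument values; in practice patience, early_stop_criterion and epochs come from the training script's argument parser, and trainer is a configured chainer trainer.
from argparse import Namespace

# Hypothetical values; the real ones come from the command-line parser.
args = Namespace(patience=3, early_stop_criterion="validation/main/acc", epochs=30)

set_early_stop(trainer, args)            # 'acc' in the criterion selects 'max' mode
trainer.run()
check_early_stop(trainer, args.epochs)   # warn if an early stop fired before the last epoch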