Utils¶
beliefstate
¶
This module provides the BeliefState class.
BeliefState
¶
A representation of the belief state, can be accessed like a dictionary.
Includes information on: * current set of UserActTypes * informs to date (dictionary where key is slot and value is {value: probability}) * requests for the current turn * number of db matches for given constraints * whether the db matches can be split further
Source code in adviser/utils/beliefstate.py
class BeliefState:
    """
    A representation of the belief state, can be accessed like a dictionary.

    The state is kept as a list of per-turn dictionaries (``_history``); the
    last entry is the current turn. Each turn dictionary includes:
        * "user_acts": current set of UserActTypes
        * "informs": informs to date (dictionary where key is slot and value
          is {value: probability})
        * "requests": requests for the current turn
        * "num_matches": number of db matches for given constraints
        * "discriminable": if the db matches can further be split
    """

    def __init__(self, domain: "JSONLookupDomain"):
        # The annotation is quoted so that defining this class does not
        # require the domain module to be importable.
        self.domain = domain
        self._history = [self._init_beliefstate()]

    def dialog_start(self):
        """Drops the whole history and starts over with one fresh turn."""
        self._history = [self._init_beliefstate()]

    def __getitem__(self, val):  # for indexing
        # if used with numbers: int (e.g. state[-2]) or slice (e.g. state[3:6])
        if isinstance(val, (int, slice)):
            return self._history[val]  # interpret the number as turn
        # if used with strings (e.g. state['beliefs'])
        if isinstance(val, str):
            # take the current turn's belief state
            return self._history[-1][val]
        # any other key type yields None (kept for backward compatibility)
        return None

    def __iter__(self):
        return iter(self._history[-1])

    def __setitem__(self, key, val):
        # e.g. state['beliefs']['area']['west'] = 1.0
        self._history[-1][key] = val

    def __len__(self):
        # number of turns stored in the history
        return len(self._history)

    def __contains__(self, val):
        # membership is checked against the current turn only
        return val in self._history[-1]

    def _recursive_repr(self, sub_dict, indent=0):
        """Renders a (possibly nested) dict as a string, recursing into dict values."""
        if not isinstance(sub_dict, dict):
            return str(sub_dict) + ' '
        parts = ['{']
        for key in sub_dict:
            parts.append("'" + str(key) + "': ")
            parts.append(self._recursive_repr(sub_dict[key], indent + 2))
        parts.append('}\n' + ' ' * indent)
        return ''.join(parts)

    def __repr__(self):
        return str(self._history[-1])

    def __str__(self):
        return self._recursive_repr(self._history[-1])

    def start_new_turn(self):
        """
        ONLY to be called by the belief state tracker at the beginning of each
        turn, to ensure the correct history can be accessed by other modules.
        """
        # copy last turn's dict
        self._history.append(copy.deepcopy(self._history[-1]))

    def _init_beliefstate(self):
        """Initializes the belief state of a single turn.

        Returns:
            (dict): dict with the entries "user_acts", "informs", "requests",
                    "num_matches" and "discriminable"
        """
        # TODO: revisit when we include probabilities, sets should become dictionaries
        return {"user_acts": set(),
                "informs": {},
                "requests": {},
                "num_matches": 0,
                "discriminable": True}

    @staticmethod
    def _top_values(slot_beliefs: dict, threshold: float, max_results: int) -> list:
        """Returns at most `max_results` values of one slot (most probable first)
        whose probability is at least `threshold`."""
        ranked = sorted(slot_beliefs.items(), key=lambda kv: kv[1], reverse=True)
        return [value for value, prob in ranked[:max_results] if prob >= threshold]

    def get_most_probable_slot_beliefs(self, slot: str, consider_NONE: bool = True,
                                       threshold: float = 0.7,
                                       max_results: int = 1, turn_idx: int = -1):
        """ Extract the most probable value(s) for one informed slot.

        Args:
            slot: the slot to look up in the "informs" entry
            consider_NONE: currently not evaluated by this method
            threshold: minimum probability for a value to be accepted
            max_results: return at most `max_results` best values
            turn_idx: index for accessing the belief state history (default = -1: use last turn)

        Returns:
            List[str]: the accepted values, most probable first; empty if the
            slot is unknown or no value reaches the threshold.
        """
        informs = self._history[turn_idx]["informs"]
        if slot not in informs:
            return []
        # previously the filtered values were computed but an empty list was returned
        return self._top_values(informs[slot], threshold, max_results)

    def get_most_probable_inf_beliefs(self, consider_NONE: bool = True, threshold: float = 0.7,
                                      max_results: int = 1, turn_idx: int = -1):
        """ Extract the most probable value(s) for every informed slot.

        If the most probable value for a slot does not reach the threshold,
        then the slot will not be added to the result at all.

        Args:
            consider_NONE: currently not evaluated by this method
            threshold: minimum probability to be accepted
            max_results: return at most `max_results` best values per slot
            turn_idx: index for accessing the belief state history (default = -1: use last turn)

        Returns:
            Union(Dict[str, List[str]], Dict[str, str]): A dict with mapping from slots to a
            list of values (if max_results != 1) or a single value (if max_results == 1),
            containing only the slots with at least one accepted value.
        """
        candidates = {}
        informs = self._history[turn_idx]["informs"]
        for slot, slot_beliefs in informs.items():
            accepted = self._top_values(slot_beliefs, threshold, max_results)
            if accepted:
                candidates[slot] = accepted[0] if max_results == 1 else accepted
        return candidates

    def get_requested_slots(self, turn_idx: int = -1):
        """ Returns the slots requested by the user as a list.

        Args:
            turn_idx: index for accessing the belief state history (default = -1: use last turn)
        """
        return list(self._history[turn_idx]['requests'])

    def _remove_dontcare_slots(self, slot_value_dict: dict):
        """ Returns a new dictionary without the slots set to dontcare """
        return {slot: value for slot, value in slot_value_dict.items()
                if value != 'dontcare'}

    def get_num_dbmatches(self):
        """ Counts the database matches for the constraints of the current turn.

        The belief state itself is not modified; callers have to store the
        result if desired.

        Returns:
            Tuple[int, bool]: number of matching entities, and whether the
            matches differ in at least one informable slot (other than the
            primary key and slots the user does not care about).
        """
        # check how many db entities match the current constraints
        candidates = self.get_most_probable_inf_beliefs(consider_NONE=True, threshold=0.7,
                                                        max_results=1)
        constraints = self._remove_dontcare_slots(candidates)
        informable = self.domain.get_informable_slots()
        db_matches = self.domain.find_entities(constraints, informable)
        num_matches = len(db_matches)

        # check if matching db entities could be discriminated by more
        # information from user
        discriminable = False
        if num_matches > 1:
            dontcare_slots = set(candidates.keys()) - set(constraints.keys())
            primary_key = self.domain.get_primary_key()
            # set('name') would yield the characters of the key, so wrap a
            # string key explicitly
            primary_keys = {primary_key} if isinstance(primary_key, str) else set(primary_key)
            usable_slots = set(informable) - primary_keys - dontcare_slots
            # a slot with at least 2 different values can differentiate entities
            discriminable = any(
                len({db_match[slot] for db_match in db_matches}) > 1
                for slot in usable_slots)
        return num_matches, discriminable
__contains__(self, val)
special
¶
__getitem__(self, val)
special
¶
Source code in adviser/utils/beliefstate.py
def __getitem__(self, val):  # for indexing
    """Index by turn (int or slice) or by key of the current turn (str)."""
    # strings (e.g. state['beliefs']) address the current turn's belief state
    if isinstance(val, str):
        current_turn = self._history[-1]
        return current_turn[val]
    # numbers: int (e.g. state[-2]) or slice (e.g. state[3:6]) address turns
    if isinstance(val, (int, slice)):
        return self._history[val]
    # every other key type falls through and yields None
    return None
__init__(self, domain)
special
¶
__iter__(self)
special
¶
__len__(self)
special
¶
__repr__(self)
special
¶
__setitem__(self, key, val)
special
¶
__str__(self)
special
¶
dialog_start(self)
¶
get_most_probable_inf_beliefs(self, consider_NONE=True, threshold=0.7, max_results=1, turn_idx=-1)
¶
Extract the most probable value for each system requestable slot
If the most probable value for a slot does not exceed the threshold, then the slot will not be added to the result at all.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
consider_NONE |
bool |
If True, slots where NONE values have the highest probability will not be added to the result. If False, slots where NONE values have the highest probability will look for the best value != NONE. |
True |
threshold |
float |
minimum probability to be accepted |
0.7 |
max_results |
int |
return at most `max_results` best values per slot |
1 |
turn_idx |
int |
index for accessing the belief state history (default = -1: use last turn) |
-1 |
Returns:
Type | Description |
---|---|
Union(Dict[str, List[str]], Dict[str, str]) |
A dict with mapping from slots to a list of values (if max_results > 1) or a value (if max_results == 1) containing the slots which have at least one value whose probability exceeds the specified threshold. |
Source code in adviser/utils/beliefstate.py
def get_most_probable_inf_beliefs(self, consider_NONE: bool = True, threshold: float = 0.7,
                                  max_results: int = 1, turn_idx: int = -1):
    """ Extract the most probable value(s) for every informed slot.

    A slot whose best values do not reach the threshold is left out of the
    result entirely.

    Args:
        consider_NONE: not evaluated by this implementation
        threshold: minimum probability to be accepted
        max_results: return at most `max_results` best values per slot
        turn_idx: index for accessing the belief state history (default = -1: use last turn)

    Returns:
        Union(Dict[str, List[str]], Dict[str, str]): mapping from slot to a
        single value (if max_results == 1) or to a list of values (otherwise).
    """
    single_value = max_results == 1
    result = {}
    for slot, beliefs in self._history[turn_idx]["informs"].items():
        # most probable values first, cut down to the requested maximum
        ranked = sorted(beliefs.items(), key=lambda kv: kv[1], reverse=True)
        top = ranked[:max_results]
        # keep only the values that are probable enough
        accepted = [value for value, prob in top if prob >= threshold]
        if not accepted:
            continue
        result[slot] = accepted[0] if single_value else accepted
    return result
get_most_probable_slot_beliefs(self, slot, consider_NONE=True, threshold=0.7, max_results=1, turn_idx=-1)
¶
Extract the most probable value for each system requestable slot
If the most probable value for a slot does not exceed the threshold, then the slot will not be added to the result at all.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
consider_NONE |
bool |
If True, slots where NONE values have the highest probability will not be added to the result. If False, slots where NONE values have the highest probability will look for the best value != NONE. |
True |
threshold |
float |
minimum probability for a value to be accepted into the result |
0.7 |
max_results |
int |
return at most #max_results best values per slot |
1 |
turn_idx |
int |
index for accessing the belief state history (default = -1: use last turn) |
-1 |
Returns:
Type | Description |
---|---|
Union(Dict[str, List[str]], Dict[str, str]) |
A dict with mapping from slots to a list of values (if max_results > 1) or a value (if max_results == 1) containing the slots which have at least one value whose probability exceeds the specified threshold. |
Source code in adviser/utils/beliefstate.py
def get_most_probable_slot_beliefs(self, slot: str, consider_NONE: bool = True,
                                   threshold: float = 0.7,
                                   max_results: int = 1, turn_idx: int = -1):
    """ Extract the most probable value(s) for one informed slot.

    Args:
        slot: the slot to look up in the "informs" entry
        consider_NONE: not evaluated by this implementation
        threshold: minimum probability for a value to be accepted
        max_results: return at most `max_results` best values
        turn_idx: index for accessing the belief state history (default = -1: use last turn)

    Returns:
        List[str]: the accepted values, most probable first; empty if the
        slot is unknown or no value reaches the threshold.
    """
    informs = self._history[turn_idx]["informs"]
    if slot not in informs:
        return []
    sorted_slot_cands = sorted(informs[slot].items(), key=lambda kv: kv[1], reverse=True)
    # restrict result count to specified maximum, then threshold by
    # probabilities; the filtered values are the result (the original
    # computed them but returned an always-empty list)
    return [value for value, prob in sorted_slot_cands[:max_results]
            if prob >= threshold]
get_num_dbmatches(self)
¶
Returns the number of database matches for the constraints in the current turn, together with a flag telling whether these matches could be discriminated further by more information from the user.
Source code in adviser/utils/beliefstate.py
def get_num_dbmatches(self):
    """ Counts the database matches for the constraints of the current turn.

    The belief state itself is not modified; callers have to store the
    result if desired.

    Returns:
        Tuple[int, bool]: number of matching entities, and whether the
        matches differ in at least one informable slot (other than the
        primary key and slots the user does not care about).
    """
    # check how many db entities match the current constraints
    candidates = self.get_most_probable_inf_beliefs(consider_NONE=True, threshold=0.7,
                                                    max_results=1)
    constraints = self._remove_dontcare_slots(candidates)
    informable = self.domain.get_informable_slots()
    db_matches = self.domain.find_entities(constraints, informable)
    num_matches = len(db_matches)

    # check if matching db entities could be discriminated by more
    # information from user
    discriminable = False
    if num_matches > 1:
        dontcare_slots = set(candidates.keys()) - set(constraints.keys())
        primary_key = self.domain.get_primary_key()
        # set('name') would yield the characters of the key, so wrap a string
        # key explicitly
        primary_keys = {primary_key} if isinstance(primary_key, str) else set(primary_key)
        usable_slots = set(informable) - primary_keys - dontcare_slots
        # a slot with at least 2 different values can differentiate entities
        discriminable = any(
            len({db_match[slot] for db_match in db_matches}) > 1
            for slot in usable_slots)
    return num_matches, discriminable
get_requested_slots(self, turn_idx=-1)
¶
Returns the slots requested by the user
Parameters:
Name | Type | Description | Default |
---|---|---|---|
turn_idx |
int |
index for accessing the belief state history (default = -1: use last turn) |
-1 |
Source code in adviser/utils/beliefstate.py
def get_requested_slots(self, turn_idx: int = -1):
    """ Lists the slots stored under 'requests' for the given turn.

    Args:
        turn_idx: index for accessing the belief state history (default = -1: use last turn)

    Returns:
        a new list with the requested slots, in iteration order
    """
    turn_state = self._history[turn_idx]
    return list(turn_state['requests'])
start_new_turn(self)
¶
ONLY to be called by the belief state tracker at the beginning of each turn, to ensure that the correct history can be accessed by other modules.
Source code in adviser/utils/beliefstate.py
common
¶
This module provides a method to seed commonly used random generators.
GLOBAL_SEED
¶
Language (Enum)
¶
init_random(seed=None)
¶
Initializes the random generators to allow seeding.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
seed |
int |
The seed used for all random generators. |
None |
Source code in adviser/utils/common.py
def init_random(seed: int = None):
    """
    Initializes the random generators to allow seeding.

    Seeds `numpy.random` and `random`, plus `torch` and `tensorflow` when
    they can be imported. Only the first call has an effect: once
    GLOBAL_SEED is set, later calls leave all generators untouched.

    Args:
        seed (int): The seed used for all random generators. If None, a
                    random 32 bit seed is drawn.

    Returns:
        int: the seed that is in effect (GLOBAL_SEED)
    """
    global GLOBAL_SEED  # pylint: disable=global-statement
    if GLOBAL_SEED is not None:
        # already seeded: report the active seed instead of None
        return GLOBAL_SEED

    if seed is None:
        tmp_random = numpy.random.RandomState(None)
        # convert to a plain int: random.seed() only supports builtin seed
        # types, not numpy scalars
        GLOBAL_SEED = int(tmp_random.randint(2**32-1, dtype='uint32'))
    else:
        GLOBAL_SEED = seed

    # initialize random generators
    numpy.random.seed(GLOBAL_SEED)
    random.seed(GLOBAL_SEED)

    try:
        # try to load torch and initialize random generator if available
        import torch
    except ImportError:
        pass
    else:
        torch.cuda.manual_seed_all(GLOBAL_SEED)  # gpu
        torch.manual_seed(GLOBAL_SEED)  # cpu

    try:
        # try to load tensorflow and initialize random generator if available
        import tensorflow
    except ImportError:
        pass
    else:
        # TF 2.x names the function set_seed, TF 1.x set_random_seed
        tf_set_seed = getattr(tensorflow.random, 'set_seed', None)
        if tf_set_seed is None:
            tf_set_seed = tensorflow.random.set_random_seed
        tf_set_seed(GLOBAL_SEED)

    # check whether all calls to torch.* use the same random generator (i.e. same instance)
    # works in a short test -- MS
    # print(torch.initial_seed())

    # logger.info("Seed is {:d}".format(GLOBAL_SEED))
    return GLOBAL_SEED
domain
special
¶
domain
¶
Domain
¶
Abstract class for linking a domain with a data access method.
Derive from this class if you need to implement a domain with a not yet supported data backend, otherwise choose a fitting existing child class.
Source code in adviser/utils/domain/domain.py
class Domain(object):
    """ Base class that ties a domain (identified by its name) to a data
    access method.

    Subclass it to support a data backend that has no implementation yet;
    otherwise pick one of the existing child classes. """

    def __init__(self, name: str):
        # identifier of the domain, returned by get_domain_name()
        self.name = name

    def get_domain_name(self) -> str:
        """ Return the domain name of the current ontology.

        Returns:
            str: the name passed to the constructor
        """
        domain_name = self.name
        return domain_name

    def find_entities(self, constraints: dict):
        """ Returns all entities from the data backend that meet the constraints.

        Args:
            constraints (dict): slot-value mapping of constraints

        Raises:
            NotImplementedError: always; this function must be overridden!
        """
        raise NotImplementedError
__init__(self, name)
special
¶
find_entities(self, constraints)
¶
Returns all entities from the data backend that meet the constraints.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
constraints |
dict |
slot-value mapping of constraints |
required |
IMPORTANT: This function must be overridden!
get_domain_name(self)
¶
jsonlookupdomain
¶
JSONLookupDomain (Domain)
¶
Abstract class for linking a domain based on a JSON ontology with a database access method (SQLite).
Source code in adviser/utils/domain/jsonlookupdomain.py
class JSONLookupDomain(Domain):
    """ Abstract class for linking a domain based on a JSON-ontology with a database
       access method (sqlite).
    """

    def __init__(self, name: str, json_ontology_file: str = None, sqllite_db_file: str = None, \
                 display_name: str = None):
        """ Loads the ontology from a json file and the data from a sqlite
            database.

            To create a new domain using this format, inherit from this class
            and overwrite the _get_domain_name_()-method to return your
            domain's name.

        Arguments:
            name (str): the domain's name used as an identifier
            json_ontology_file (str): relative path to the ontology file
                                (from the top-level adviser directory, e.g. resources/ontologies);
                                defaults to resources/ontologies/<name>.json
            sqllite_db_file (str): relative path to the database file
                                (from the top-level adviser directory, e.g. resources/databases);
                                defaults to resources/databases/<name>.db
            display_name (str): the domain's name as it appears on the screen
                                (e.g. containing whitespaces); defaults to `name`
        """
        super(JSONLookupDomain, self).__init__(name)

        root_dir = self._get_root_dir()
        # keep the caller's value (possibly None): query_db() re-derives the
        # default path when it has to reload the database
        self.sqllite_db_file = sqllite_db_file
        # make sure to set default values in case of None
        json_ontology_file = json_ontology_file or os.path.join('resources', 'ontologies',
                                                                name + '.json')
        sqllite_db_file = sqllite_db_file or os.path.join('resources', 'databases',
                                                          name + '.db')

        # context manager closes the ontology file again after parsing
        with open(root_dir + '/' + json_ontology_file) as ontology_file:
            self.ontology_json = json.load(ontology_file)
        # load database
        self.db = self._load_db_to_memory(root_dir + '/' + sqllite_db_file)

        self.display_name = display_name if display_name is not None else name

    def __getstate__(self):
        # remove sql connection from state dict so that pickling works;
        # query_db() reloads the database when the attribute is missing
        state = self.__dict__.copy()
        state.pop('db', None)
        return state

    def _get_root_dir(self):
        """ Returns the path to the root directory (three levels above this file) """
        return os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

    def _sqllite_dict_factory(self, cursor, row):
        """ Convert sqlite row into a dictionary (column name -> value) """
        # first entry of each cursor.description item is the column name
        return {col[0]: row[col_idx] for col_idx, col in enumerate(cursor.description)}

    def _load_db_to_memory(self, db_file_path: str):
        """ Loads a sqlite3 database from file to memory in order to save
            I/O operations

        Args:
            db_file_path (str): absolute path to database file

        Returns:
            A sqlite3 connection to the in-memory copy; rows are returned as dicts
        """
        # open and dump db file to a temporary buffer
        file_db = sqlite3.connect(db_file_path, check_same_thread=False)
        dump = StringIO()
        try:
            for line in file_db.iterdump():
                dump.write('%s\n' % line)
        finally:
            # release the file connection even if dumping fails
            file_db.close()

        # Create a database in memory and import from the buffer
        db = sqlite3.connect(':memory:', check_same_thread=False)
        db.row_factory = self._sqllite_dict_factory
        db.cursor().executescript(dump.getvalue())
        db.commit()
        # file_db.backup(databases[domain]) # works only in python >= 3.7
        return db

    def find_entities(self, constraints: dict, requested_slots: Iterable = ()):
        """ Returns all entities from the data backend that meet the constraints, with values for
            the primary key and the system requestable slots (and optional slots, specifiable
            via requested_slots).

        Args:
            constraints (dict): Slot-value mapping of constraints. Entries with value None or
                                'dontcare' (case-insensitive) are ignored.
                                If empty, all entities in the database will be returned.
            requested_slots (Iterable): list of slots that should be returned in addition to the
                                        system requestable slots and the primary key

        Returns:
            rows of the query response set, as returned by query_db()
        """
        # values for name and all system requestable slots
        select_clause = ", ".join({self.get_primary_key()} |
                                  set(self.get_system_requestable_slots()) |
                                  set(requested_slots))
        query = "SELECT {} FROM {}".format(select_clause, self.get_domain_name())
        # NOTE(review): the query is built by string formatting; values are
        # escaped by doubling single quotes, slot names are inserted verbatim.
        # str() makes the escaping work for non-string values as well.
        constraints = {slot: str(value).replace("'", "''")
                       for slot, value in constraints.items()
                       if value is not None and str(value).lower() != 'dontcare'}
        if constraints:
            query += ' WHERE ' + ' AND '.join("{}='{}' COLLATE NOCASE".format(key, val)
                                              for key, val in constraints.items())
        return self.query_db(query)

    def find_info_about_entity(self, entity_id, requested_slots: Iterable):
        """ Returns the values (stored in the data backend) of the specified slots for the
            specified entity.

        Args:
            entity_id (str): primary key value of the entity
            requested_slots (Iterable): names of the slots whose values should be returned;
                                        if empty, all columns are returned
        """
        if requested_slots:
            select_clause = ", ".join(sorted(requested_slots))
        # If the user hasn't specified any slots we don't know what they want so we give everything
        else:
            select_clause = "*"
        # single-quoted SQL string literal with doubled quotes: a double-quoted
        # value may be resolved as an identifier by sqlite and breaks on '"'
        entity_literal = str(entity_id).replace("'", "''")
        query = "SELECT {} FROM {} WHERE {}='{}';".format(
            select_clause, self.get_domain_name(), self.get_primary_key(), entity_literal)
        return self.query_db(query)

    def query_db(self, query_str):
        """ Function for querying the sqlite3 db

        Args:
            query_str (string): sqlite3 query style string

        Return:
            (iterable): rows of the query response set
        """
        if "db" not in self.__dict__:
            # connection is missing (it is dropped by __getstate__): reload
            root_dir = self._get_root_dir()
            sqllite_db_file = self.sqllite_db_file or os.path.join(
                'resources', 'databases', self.name + '.db')
            self.db = self._load_db_to_memory(root_dir + '/' + sqllite_db_file)
        cursor = self.db.cursor()
        cursor.execute(query_str)
        return cursor.fetchall()

    def get_display_name(self):
        """ Returns the domain's name as it appears on the screen. """
        return self.display_name

    def get_requestable_slots(self) -> List[str]:
        """ Returns a list of all slots requestable by the user. """
        return self.ontology_json['requestable']

    def get_system_requestable_slots(self) -> List[str]:
        """ Returns a list of all slots requestable by the system. """
        return self.ontology_json['system_requestable']

    def get_informable_slots(self) -> List[str]:
        """ Returns all informable slots (the keys of the ontology's 'informable' entry). """
        return self.ontology_json['informable'].keys()

    def get_possible_values(self, slot: str) -> List[str]:
        """ Returns all possible values for an informable slot

        Args:
            slot (str): name of the slot

        Returns:
            a list of strings, each string representing one possible value for
            the specified slot.
         """
        return self.ontology_json['informable'][slot]

    def get_primary_key(self):
        """ Returns the name of a column in the associated database which can be used to uniquely
            distinguish between database entities.
            Could be e.g. the name of a restaurant, an ID, ... """
        return self.ontology_json['key']

    def get_pronouns(self, slot):
        """ Returns the ontology's 'pronoun_map' entry for the slot, or an empty list. """
        if slot in self.ontology_json['pronoun_map']:
            return self.ontology_json['pronoun_map'][slot]
        return []

    def get_keyword(self):
        """ Returns the ontology's 'keyword' entry, or None if there is none. """
        if "keyword" in self.ontology_json:
            return self.ontology_json['keyword']
        return None
__getstate__(self)
special
¶
__init__(self, name, json_ontology_file=None, sqllite_db_file=None, display_name=None)
special
¶
Loads the ontology from a json file and the data from a sqllite database.
To create a new domain using this format, inherit from this class
and overwrite the _get_domain_name_()-method to return your
domain's name.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
the domain's name used as an identifier |
required |
json_ontology_file |
str |
relative path to the ontology file (from the top-level adviser directory, e.g. resources/ontologies) |
None |
sqllite_db_file |
str |
relative path to the database file (from the top-level adviser directory, e.g. resources/databases) |
None |
display_name |
str |
the domain's name as it appears on the screen (e.g. containing whitespaces) |
None |
Source code in adviser/utils/domain/jsonlookupdomain.py
def __init__(self, name: str, json_ontology_file: str = None, sqllite_db_file: str = None, \
             display_name: str = None):
    """ Loads the ontology from a json file and the data from a sqlite
        database.

        To create a new domain using this format, inherit from this class
        and overwrite the _get_domain_name_()-method to return your
        domain's name.

    Arguments:
        name (str): the domain's name used as an identifier
        json_ontology_file (str): relative path to the ontology file
                            (from the top-level adviser directory, e.g. resources/ontologies);
                            defaults to resources/ontologies/<name>.json
        sqllite_db_file (str): relative path to the database file
                            (from the top-level adviser directory, e.g. resources/databases);
                            defaults to resources/databases/<name>.db
        display_name (str): the domain's name as it appears on the screen
                            (e.g. containing whitespaces); defaults to `name`
    """
    super(JSONLookupDomain, self).__init__(name)

    root_dir = self._get_root_dir()
    # keep the caller's value (possibly None) on the instance
    self.sqllite_db_file = sqllite_db_file
    # make sure to set default values in case of None
    json_ontology_file = json_ontology_file or os.path.join('resources', 'ontologies',
                                                            name + '.json')
    sqllite_db_file = sqllite_db_file or os.path.join('resources', 'databases',
                                                      name + '.db')

    # context manager closes the ontology file again after parsing
    with open(root_dir + '/' + json_ontology_file) as ontology_file:
        self.ontology_json = json.load(ontology_file)
    # load database
    self.db = self._load_db_to_memory(root_dir + '/' + sqllite_db_file)

    self.display_name = display_name if display_name is not None else name
find_entities(self, constraints, requested_slots=<tuple_iterator object at 0x7f1f8376d0d0>)
¶
Returns all entities from the data backend that meet the constraints, with values for the primary key and the system requestable slots (and optional slots, specifyable via requested_slots).
Parameters:
Name | Type | Description | Default |
---|---|---|---|
constraints |
dict |
Slot-value mapping of constraints. If empty, all entities in the database will be returned. |
required |
requested_slots |
Iterable |
list of slots that should be returned in addition to the system requestable slots and the primary key |
empty iterable |
Source code in adviser/utils/domain/jsonlookupdomain.py
def find_entities(self, constraints: dict, requested_slots: Iterable = ()):
    """ Returns all entities from the data backend that meet the constraints, with values for
        the primary key and the system requestable slots (and optional slots, specifiable
        via requested_slots).

    Args:
        constraints (dict): Slot-value mapping of constraints. Entries with value None or
                            'dontcare' (case-insensitive) are ignored.
                            If empty, all entities in the database will be returned.
        requested_slots (Iterable): list of slots that should be returned in addition to the
                                    system requestable slots and the primary key

    Returns:
        rows of the query response set, as returned by query_db()
    """
    # values for name and all system requestable slots
    select_clause = ", ".join({self.get_primary_key()} |
                              set(self.get_system_requestable_slots()) |
                              set(requested_slots))
    query = "SELECT {} FROM {}".format(select_clause, self.get_domain_name())
    # NOTE(review): the query is built by string formatting; values are
    # escaped by doubling single quotes, slot names are inserted verbatim.
    # str() makes the escaping work for non-string values as well.
    constraints = {slot: str(value).replace("'", "''")
                   for slot, value in constraints.items()
                   if value is not None and str(value).lower() != 'dontcare'}
    if constraints:
        query += ' WHERE ' + ' AND '.join("{}='{}' COLLATE NOCASE".format(key, val)
                                          for key, val in constraints.items())
    return self.query_db(query)
find_info_about_entity(self, entity_id, requested_slots)
¶
Returns the values (stored in the data backend) of the specified slots for the specified entity.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
entity_id |
str |
primary key value of the entity |
required |
requested_slots |
Iterable |
names of the slots whose values should be returned; if empty, all columns are returned |
required |
Source code in adviser/utils/domain/jsonlookupdomain.py
def find_info_about_entity(self, entity_id, requested_slots: Iterable):
    """ Returns the values (stored in the data backend) of the specified slots for the
        specified entity.

    Args:
        entity_id (str): primary key value of the entity
        requested_slots (Iterable): names of the slots whose values should be returned;
                                    if empty, all columns are returned

    Returns:
        rows of the query response set, as returned by query_db()
    """
    if requested_slots:
        select_clause = ", ".join(sorted(requested_slots))
    # If the user hasn't specified any slots we don't know what they want so we give everything
    else:
        select_clause = "*"
    # single-quoted SQL string literal with doubled quotes: a double-quoted
    # value may be resolved as an identifier by sqlite and breaks on '"'
    entity_literal = str(entity_id).replace("'", "''")
    query = "SELECT {} FROM {} WHERE {}='{}';".format(
        select_clause, self.get_domain_name(), self.get_primary_key(), entity_literal)
    return self.query_db(query)
get_display_name(self)
¶
get_informable_slots(self)
¶
get_keyword(self)
¶
get_possible_values(self, slot)
¶
Returns all possible values for an informable slot
Parameters:
Name | Type | Description | Default |
---|---|---|---|
slot |
str |
name of the slot |
required |
Returns:
Type | Description |
---|---|
List[str] |
a list of strings, each string representing one possible value for the specified slot. |
Source code in adviser/utils/domain/jsonlookupdomain.py
get_primary_key(self)
¶
Returns the name of a column in the associated database which can be used to uniquely distinguish between database entities. Could be e.g. the name of a restaurant, an ID, ...
get_pronouns(self, slot)
¶
get_requestable_slots(self)
¶
get_system_requestable_slots(self)
¶
query_db(self, query_str)
¶
Function for querying the sqlite3 db
Parameters:
Name | Type | Description | Default |
---|---|---|---|
query_str |
string |
sqlite3 query style string |
required |
Returns:
Type | Description |
---|---|
(iterable) |
rows of the query response set |
Source code in adviser/utils/domain/jsonlookupdomain.py
def query_db(self, query_str):
    """ Runs a query against the sqlite3 database and fetches all rows.

    If the instance has no `db` attribute, the database is loaded into
    memory first.

    Args:
        query_str (string): sqlite3 query style string

    Return:
        (iterable): rows of the query response set
    """
    has_connection = "db" in vars(self)
    if not has_connection:
        base_dir = self._get_root_dir()
        default_db_path = os.path.join('resources', 'databases', self.name + '.db') \
            if not self.sqllite_db_file else None
        relative_db_path = self.sqllite_db_file or default_db_path
        self.db = self._load_db_to_memory(base_dir + '/' + relative_db_path)
    db_cursor = self.db.cursor()
    db_cursor.execute(query_str)
    return db_cursor.fetchall()
lookupdomain
¶
LookupDomain (Domain)
¶
Abstract class for linking a domain with a data access method.
Derive from this class if you need to implement a domain with a not yet supported data backend, otherwise choose a fitting existing child class.
Source code in adviser/utils/domain/lookupdomain.py
class LookupDomain(Domain):
    """ Abstract class for linking a domain with a data access method.

    Derive from this class if you need to implement a domain with a not yet
    supported data backend, otherwise choose a fitting existing child class.

    Apart from the constructor and ``get_display_name``, every method here raises
    ``NotImplementedError`` and has to be provided by the subclass.
    """

    def __init__(self, identifier: str, display_name: str):
        """
        Args:
            identifier (str): passed on unchanged to ``Domain.__init__``
            display_name (str): name returned by ``get_display_name``
        """
        Domain.__init__(self, identifier)
        self.display_name = display_name

    # The default used to be ``iter(())``: a single iterator object created at
    # definition time and shared by all calls, which also shows up as an opaque
    # ``<tuple_iterator object at 0x...>`` in generated signatures. An empty tuple
    # is an equally empty Iterable, but immutable and reusable.
    def find_entities(self, constraints: dict, requested_slots: Iterable = ()):
        """ Returns all entities from the data backend that meet the constraints.

        Args:
            constraints (dict): slot-value mapping of constraints
            requested_slots (Iterable): slots requested for the matching entities;
                                        empty by default

        IMPORTANT: This function must be overridden!
        """
        raise NotImplementedError

    def find_info_about_entity(self, entity_id, requested_slots: Iterable):
        """ Returns the values (stored in the data backend) of the specified slots for the
            specified entity.

        Args:
            entity_id (str): primary key value of the entity
            requested_slots (Iterable): the slots whose values should be returned
        """
        raise NotImplementedError

    def get_display_name(self):
        """ Returns the display name given to the constructor. """
        return self.display_name

    def get_requestable_slots(self) -> List[str]:
        """ Returns a list of all slots requestable by the user. """
        raise NotImplementedError

    def get_system_requestable_slots(self) -> List[str]:
        """ Returns a list of all slots requestable by the system. """
        raise NotImplementedError

    def get_informable_slots(self) -> List[str]:
        """ Returns a list of all informable slots. """
        raise NotImplementedError

    def get_mandatory_slots(self) -> List[str]:
        """Returns a list of all mandatory slots.

        Slots are called mandatory if their value is required by the system before it can even
        generate a candidate list.
        """
        raise NotImplementedError

    def get_default_inform_slots(self) -> List[str]:
        """Returns a list of all default Inform slots.

        Default Inform slots are always added to (system) Inform actions, even if the user has not
        implicitly asked for it. Note that these slots are different from the primary key slot.
        """
        raise NotImplementedError

    def get_possible_values(self, slot: str) -> List[str]:
        """ Returns all possible values for an informable slot

        Args:
            slot (str): name of the slot

        Returns:
            a list of strings, each string representing one possible value for
            the specified slot.
        """
        raise NotImplementedError

    def get_primary_key(self) -> str:
        """ Returns the slot name that will be used as the 'name' of an entry """
        raise NotImplementedError

    def get_keyword(self):
        """ Must be overridden; the expected return value is not specified here. """
        raise NotImplementedError
__init__(self, identifier, display_name)
special
¶
find_entities(self, constraints, requested_slots=<tuple_iterator object at 0x7f1f8376f8b0>)
¶
Returns all entities from the data backend that meet the constraints.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
constraints |
dict |
slot-value mapping of constraints |
required |
IMPORTANT: This function must be overridden!
Source code in adviser/utils/domain/lookupdomain.py
find_info_about_entity(self, entity_id, requested_slots)
¶
Returns the values (stored in the data backend) of the specified slots for the specified entity.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
entity_id |
str |
primary key value of the entity |
required |
requested_slots |
Iterable |
the slots whose values should be returned for the entity |
required |
Source code in adviser/utils/domain/lookupdomain.py
def find_info_about_entity(self, entity_id, requested_slots: Iterable):
""" Returns the values (stored in the data backend) of the specified slots for the
specified entity.
Args:
entity_id (str): primary key value of the entity
requested_slots (Iterable): the slots whose values should be returned
Raises:
NotImplementedError: always; this base version contains no data access
"""
# Nothing is looked up here: the body consists of this raise only.
raise NotImplementedError
get_default_inform_slots(self)
¶
Returns a list of all default Inform slots.
Default Inform slots are always added to (system) Inform actions, even if the user has not implicitly asked for it. Note that these slots are different from the primary key slot.
Source code in adviser/utils/domain/lookupdomain.py
def get_default_inform_slots(self) -> List[str]:
"""Returns a list of all default Inform slots.
Default Inform slots are always added to (system) Inform actions, even if the user has not
implicitly asked for it. Note that these slots are different from the primary key slot.
Raises:
NotImplementedError: always; this base version returns nothing
"""
# Nothing is computed here: the body consists of this raise only.
raise NotImplementedError
get_display_name(self)
¶
get_informable_slots(self)
¶
get_keyword(self)
¶
get_mandatory_slots(self)
¶
Returns a list of all mandatory slots.
Slots are called mandatory if their value is required by the system before it can even generate a candidate list.
get_possible_values(self, slot)
¶
Returns all possible values for an informable slot
Parameters:
Name | Type | Description | Default |
---|---|---|---|
slot |
str |
name of the slot |
required |
Returns:
Type | Description |
---|---|
List[str] |
a list of strings, each string representing one possible value for the specified slot. |
Source code in adviser/utils/domain/lookupdomain.py
get_primary_key(self)
¶
get_requestable_slots(self)
¶
get_system_requestable_slots(self)
¶
logger
¶
This module provides a logger for configurable output on different levels.
DiasysLogger (Logger)
¶
Logger class.
This class enables logging to both a logfile and the console with different information levels. It also provides logging methods for the newly introduced information levels (LogLevel.DIALOGS and LogLevel.RESULTS).
If file_level is set to LogLevel.NONE, no log file will be created. Otherwise, the output directory can be configured by setting log_folder.
Source code in adviser/utils/logger.py
class DiasysLogger(logging.Logger):
    """Logger writing to the console and, optionally, to a time-stamped log file.

    Console and file output each have their own level. Two convenience methods
    cover the additional levels LogLevel.DIALOGS and LogLevel.RESULTS.

    With file_log_lvl set to LogLevel.NONE no log file is created; otherwise the
    file is placed in logfile_folder.
    """

    def __init__(self, name: str = 'adviser', console_log_lvl: LogLevel = LogLevel.ERRORS,
                 file_log_lvl: LogLevel = LogLevel.NONE, logfile_folder: str = 'logs',
                 logfile_basename: str = 'log'):  # pylint: disable=too-many-arguments
        super().__init__(name)

        # --- optional file output ---
        if file_log_lvl is not LogLevel.NONE:
            target_dir = os.path.realpath(logfile_folder)
            os.makedirs(target_dir, exist_ok=True)
            # one file per logger instance, named after its creation time
            timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
            to_file = logging.FileHandler(
                os.path.join(target_dir, logfile_basename + '_' + timestamp + '.log'),
                mode='w')
            to_file.setLevel(int(file_log_lvl))
            to_file.setFormatter(MultilineFormatter('%(asctime)s - %(message)s'))
            self.addHandler(to_file)

        # --- console output (always attached) ---
        to_console = logging.StreamHandler()
        to_console.setLevel(int(console_log_lvl))
        to_console.setFormatter(MultilineFormatter('logger: %(message)s'))
        self.addHandler(to_console)

        # Installs the hook process-wide, replacing the current sys.excepthook.
        sys.excepthook = exception_logging_hook

    def result(self, msg: str):
        """ Logs the result of a dialog """
        self.log(int(LogLevel.RESULTS), msg)

    def dialog_turn(self, msg: str, dialog_act=None):
        """ Logs a turn of a dialog, followed by the dialog act on its own line if given """
        if dialog_act is None:
            turn_text = msg
        else:
            turn_text = msg + "\n " + str(dialog_act)
        self.log(int(LogLevel.DIALOGS), turn_text)
__init__(self, name='adviser', console_log_lvl=<LogLevel.ERRORS: 40>, file_log_lvl=<LogLevel.NONE: 100>, logfile_folder='logs', logfile_basename='log')
special
¶
Source code in adviser/utils/logger.py
def __init__(self, name: str = 'adviser', console_log_lvl: LogLevel = LogLevel.ERRORS,
             file_log_lvl: LogLevel = LogLevel.NONE, logfile_folder: str = 'logs',
             logfile_basename: str = 'log'):  # pylint: disable=too-many-arguments
    """Attach a console handler and, unless file_log_lvl is LogLevel.NONE, a file handler."""
    super(DiasysLogger, self).__init__(name)

    # --- optional file output ---
    if file_log_lvl is not LogLevel.NONE:
        target_dir = os.path.realpath(logfile_folder)
        os.makedirs(target_dir, exist_ok=True)
        # one file per logger instance, named after its creation time
        timestamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        to_file = logging.FileHandler(
            os.path.join(target_dir, logfile_basename + '_' + timestamp + '.log'),
            mode='w')
        to_file.setLevel(int(file_log_lvl))
        to_file.setFormatter(MultilineFormatter('%(asctime)s - %(message)s'))
        self.addHandler(to_file)

    # --- console output (always attached) ---
    to_console = logging.StreamHandler()
    to_console.setLevel(int(console_log_lvl))
    to_console.setFormatter(MultilineFormatter('logger: %(message)s'))
    self.addHandler(to_console)

    # Installs the hook process-wide, replacing the current sys.excepthook.
    sys.excepthook = exception_logging_hook
dialog_turn(self, msg, dialog_act=None)
¶
result(self, msg)
¶
LogLevel (IntEnum)
¶
MultilineFormatter (Formatter)
¶
A formatter for the logger taking care of multiline messages.
Source code in adviser/utils/logger.py
class MultilineFormatter(logging.Formatter):
    """ A formatter for the logger taking care of multiline messages.

    Every line of a message is formatted on its own, so each output line
    carries the full prefix defined by the format string.
    """

    def format(self, record: logging.LogRecord):
        """Format ``record`` line by line and return the joined result.

        Args:
            record (logging.LogRecord): the record to format; ``msg`` and ``args``
                are modified temporarily and restored before returning.

        Returns:
            str: the formatted lines joined by newlines. The same text is stored
            in ``record.message``.
        """
        saved_msg, saved_args = record.msg, record.args
        # Resolve the message once: this converts non-string messages via str()
        # and applies %-style args to the message as a whole rather than to
        # every single line.
        message = record.getMessage()
        # splitlines() yields nothing for an empty message; still emit one
        # (empty) line so the prefix is not lost.
        lines = message.splitlines() or ['']
        formatted = []
        try:
            record.args = None  # already interpolated above
            for line in lines:
                record.msg = line
                formatted.append(super().format(record))
        finally:
            # Other handlers see the same record object, so always undo the
            # temporary changes - even if formatting failed.
            record.msg, record.args = saved_msg, saved_args
        output = "\n".join(formatted)
        record.message = output
        return output
format(self, record)
¶
Format the specified record as text.
The record's attribute dictionary is used as the operand to a string formatting operation which yields the returned string. Before formatting the dictionary, a couple of preparatory steps are carried out. The message attribute of the record is computed using LogRecord.getMessage(). If the formatting string uses the time (as determined by a call to usesTime()), formatTime() is called to format the event time. If there is exception information, it is formatted using formatException() and appended to the message.
Source code in adviser/utils/logger.py
exception_logging_hook(exc_type, exc_value, exc_traceback)
¶
sysact
¶
This module provides the necessary classes for a system action.
SysAct
¶
Source code in adviser/utils/sysact.py
class SysAct(object):
    """The class for a system action as used in the dialog."""

    def __init__(self, act_type: 'SysActionType' = None,
                 slot_values: Dict[str, List[str]] = None):
        """
        Args:
            act_type (SysActionType): The type of the system action.
            slot_values (Dict[str, List[str]]): A mapping of ``slot -> value`` to which the system
                action refers depending on the action type - might be ``None``, in which
                case an empty dictionary is used.
        """
        self.type = act_type
        self.slot_values = slot_values if slot_values is not None else {}

    def __repr__(self):
        # Single-line form; the slot part is left out when there are no slots.
        slots = self._slot_value_dict_to_str(self.slot_values)
        if slots:
            return 'SysAct(act_type={}, {})'.format(self.type, slots)
        return 'SysAct(act_type={})'.format(self.type)

    def __str__(self):
        if self.type is not None:
            return self.type.value + \
                '(' + \
                self._slot_value_dict_to_str(self.slot_values) + \
                ')'
        # Without a type there is no ``value`` to print; use the repr form
        # instead of concatenating a string with None.
        return repr(self)

    def add_value(self, slot: str, value=None):
        """ Add a value (or just a slot, if value=None) to the system act """
        if slot not in self.slot_values:
            self.slot_values[slot] = []
        if value is not None:
            self.slot_values[slot].append(value)

    def get_values(self, slot) -> list:
        """ Return all values for slot

        Returns:
            A list of values for slot or an empty list if there was no value
            specified for the given slot
        """
        if slot not in self.slot_values:
            return []
        return self.slot_values[slot]

    def __eq__(self, other):
        # Let Python handle comparisons with foreign types instead of failing
        # on a missing attribute.
        if not isinstance(other, SysAct):
            return NotImplemented
        return (self.type == other.type and
                self.slot_values == other.slot_values)

    def _slot_value_dict_to_str(self, slot_value_dict):
        """ convert dictionary to slot1=value1, slot2=value2, ... string """
        stringrep = []
        for slot, values in slot_value_dict.items():
            if not values:
                # slot without value
                stringrep.append(slot)
            elif isinstance(values, list):
                # there are values specified for slot, add them one by one
                for value in values:
                    if value is not None:
                        stringrep.append('{}="{}"'.format(slot, value))
            else:
                # a single value that is not wrapped in a list
                stringrep.append('{}="{}"'.format(slot, values))
        return ','.join(stringrep)
__eq__(self, other)
special
¶
__init__(self, act_type=None, slot_values=None)
special
¶
The class for a system action as used in the dialog.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
act_type |
SysActionType |
The type of the system action. |
None |
slot_values |
Dict[str, List[str]] |
A mapping of slot -> value to which the system action refers depending on the action type - might be None. |
None |
Source code in adviser/utils/sysact.py
def __init__(self, act_type: SysActionType = None, slot_values: Dict[str, List[str]] = None):
    """
    Create a system action as used in the dialog.

    Args:
        act_type (SysActionType): The type of the system action.
        slot_values (Dict[str, List[str]]): A mapping of ``slot -> value`` to which the system
            action refers depending on the action type - might be ``None``.
    """
    # A missing mapping becomes a fresh dictionary for this instance.
    if slot_values is None:
        slot_values = {}
    self.type = act_type
    self.slot_values = slot_values
__repr__(self)
special
¶
__str__(self)
special
¶
add_value(self, slot, value=None)
¶
Add a value (or just a slot, if value=None) to the system act
get_values(self, slot)
¶
Return all values for slot
Returns:
Type | Description |
---|---|
list |
A list of values for slot or an empty list if there was no value specified for the given slot |
SysActionType (Enum)
¶
The type for a system action as used in SysAct.
Source code in adviser/utils/sysact.py
class SysActionType(Enum):
"""The type for a system action as used in :class:`SysAct`.
Each member's value is a plain string; ``SysAct.__str__`` uses that string as
the prefix of its textual form.
"""
# member name on the left, string value on the right
Welcome = 'welcomemsg'
InformByName = 'inform_byname'
InformByAlternatives = 'inform_alternatives'
Request = 'request'
Confirm = 'confirm'
Select = 'select'
RequestMore = 'reqmore'
Bad = 'bad'
Bye = 'closingmsg'
ConfirmRequest = 'confreq'
topics
¶
Topic
¶
Source code in adviser/utils/topics.py
class Topic:
    """String constants naming the dialog life-cycle topics."""

    # Called at the beginning of a new dialog.
    # Subscribe here to set stateful variables for one dialog.
    DIALOG_START = 'dialog_start'

    # Called at the end of a dialog (after a bye-action).
    DIALOG_END = 'dialog_end'

    # Called when the dialog system shuts down.
    # Subscribe here if you e.g. have to close resource handles / free locks.
    DIALOG_EXIT = 'dialog_exit'
useract
¶
This module provides the necessary classes for a user action.
UserAct
¶
Source code in adviser/utils/useract.py
class UserAct(object):
    """The class for a user action as used in the dialog."""

    def __init__(self, text: str = "", act_type: 'UserActionType' = None, slot: str = None,
                 value: str = None, score: float = 1.0):
        """
        Args:
            text (str): A textual representation of the user action.
            act_type (UserActionType): The type of the user action.
            slot (str): The slot to which the user action refers - might be ``None`` depending on the
                user action.
            value (str): The value to which the user action refers - might be ``None`` depending on the
                user action.
            score (float): A value from 0 (not important) to 1 (important) indicating how important
                the information is for the belief state.
        """
        self.text = text
        self.type = act_type
        self.slot = slot
        self.value = value
        self.score = score

    def __repr__(self):
        return "UserAct(\"{}\", {}, {}, {}, {})".format(
            self.text, self.type, self.slot, self.value, self.score)

    def __eq__(self, other):  # to check for equality for tests
        # ``text`` is deliberately not part of the comparison.
        # Let Python handle comparisons with foreign types instead of failing
        # on a missing attribute.
        if not isinstance(other, UserAct):
            return NotImplemented
        return (self.type == other.type and
                self.slot == other.slot and
                self.value == other.value and
                self.score == other.score)

    def __hash__(self):
        # Hash exactly the fields __eq__ compares. A tuple hash is used rather
        # than a product of hashes: the product is 0 as soon as one component
        # hashes to 0 (e.g. score == 0.0), making all such acts collide.
        return hash((self.type, self.slot, self.value, self.score))
__eq__(self, other)
special
¶
__hash__(self)
special
¶
__init__(self, text='', act_type=None, slot=None, value=None, score=1.0)
special
¶
The class for a user action as used in the dialog.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
text |
str |
A textual representation of the user action. |
'' |
act_type |
UserActionType |
The type of the user action. |
None |
slot |
str |
The slot to which the user action refers - might be None depending on the user action. |
None |
value |
str |
The value to which the user action refers - might be None depending on the user action. |
None |
score |
float |
A value from 0 (not important) to 1 (important) indicating how important the information is for the belief state. |
1.0 |
Source code in adviser/utils/useract.py
def __init__(self, text: str = "", act_type: UserActionType = None, slot: str = None,
             value: str = None, score: float = 1.0):
    """
    Create a user action as used in the dialog.

    Args:
        text (str): A textual representation of the user action.
        act_type (UserActionType): The type of the user action.
        slot (str): The slot to which the user action refers - might be ``None`` depending on the
            user action.
        value (str): The value to which the user action refers - might be ``None`` depending on the
            user action.
        score (float): A value from 0 (not important) to 1 (important) indicating how important
            the information is for the belief state.
    """
    # what was said, and how it was classified
    self.text, self.type = text, act_type
    # what the act refers to
    self.slot, self.value = slot, value
    # how much it counts
    self.score = score
__repr__(self)
special
¶
UserActionType (Enum)
¶
The type for a user action as used in UserAct.
Source code in adviser/utils/useract.py
class UserActionType(Enum):
"""The type for a user action as used in :class:`UserAct`.
Each member's value is a plain lowercase string.
"""
# member name on the left, string value on the right
Inform = 'inform'
NegativeInform = 'negativeinform'
Request = 'request'
Hello = 'hello'
Bye = 'bye'
Thanks = 'thanks'
Affirm = 'affirm'
Deny = 'deny'
RequestAlternatives = 'reqalts'
Ack = 'ack'
Bad = 'bad'
Confirm = 'confirm'
SelectDomain = 'selectdomain'
userstate
¶
This module provides the UserState class.
EmotionType (Enum)
¶
EngagementType (Enum)
¶
UserState
¶
The representation of a user state. Can be accessed like a dictionary
Source code in adviser/utils/userstate.py
class UserState:
    """
    The representation of a user state.
    Can be accessed like a dictionary

    Internally a list with one dictionary per turn is kept; dictionary-style
    access always refers to the most recent turn.
    """

    def __init__(self):
        self._history = [self._init_userstate()]

    def __getitem__(self, val):  # for indexing
        """Return a turn (int / slice) or an entry of the current turn (any other key)."""
        # if used with numbers: int (e.g. state[-2]) or slice (e.g. state[3:6])
        if isinstance(val, (int, slice)):
            return self._history[val]  # interpret the number as turn
        # Every other key (e.g. state['emotion']) is looked up in the current
        # turn's user state. This mirrors __setitem__ and __contains__, which
        # accept any key, instead of silently returning None for non-strings.
        return self._history[-1][val]

    def __iter__(self):
        # iterates over the keys of the current turn
        return iter(self._history[-1])

    def __setitem__(self, key, val):
        # e.g. state['emotion'] = ...; only the current turn is changed
        self._history[-1][key] = val

    def __len__(self):
        # number of turns recorded so far
        return len(self._history)

    def __contains__(self, val):
        # membership is tested against the keys of the current turn
        return val in self._history[-1]

    def __repr__(self):
        return str(self._history[-1])

    def start_new_turn(self):
        """
        ONLY to be called by the user state tracker at the beginning of each turn,
        to ensure that other modules access the correct history
        """
        # copy last turn's dict, so changes in the new turn leave earlier turns untouched
        self._history.append(copy.deepcopy(self._history[-1]))

    def _init_userstate(self):
        """Creates the initial user state

        Returns:
            (dict): dictionary of user emotion and engagement representations
        """
        # TODO: revisit when we include probabilities, sets should become dictionaries
        user_state = {"engagement": EngagementType.Low,
                      "emotion": EmotionType.Neutral}
        return user_state
__contains__(self, val)
special
¶
__getitem__(self, val)
special
¶
Source code in adviser/utils/userstate.py
def __getitem__(self, val): # for indexing
# if used with numbers: int (e.g. state[-2]) or slice (e.g. state[3:6])
if isinstance(val, int) or isinstance(val, slice):
return self._history[val] # interpret the number as turn
# if used with strings (e.g. state['beliefs'])
elif isinstance(val, str):
# take the current turn's belief state
return self._history[-1][val]
__init__(self)
special
¶
__iter__(self)
special
¶
__len__(self)
special
¶
__repr__(self)
special
¶
__setitem__(self, key, val)
special
¶
start_new_turn(self)
¶
ONLY to be called by the user state tracker at the beginning of each turn, to ensure that other modules access the correct history