Services¶
backchannel
special
¶
PytorchAcousticBackchanneler
¶
PytorchAcousticBackchanneler (Module)
¶
Class for defining the Deep Backchannel model in PyTorch
Source code in adviser/services/backchannel/PytorchAcousticBackchanneler.py
class PytorchAcousticBackchanneler(nn.Module):
    """Deep backchannel model implemented in PyTorch.

    Two convolutional branches read the same input in parallel; their pooled
    outputs are concatenated and passed through a hidden linear layer and a
    three-way softmax output layer.
    """

    def __init__(self, parameters:list=[], load_params:bool=False):
        """
        Builds the layers and, on request, initialises them from pretrained arrays.

        Args:
            parameters (list): pretrained ``[weights, bias]`` pairs in the order
                first CNN, second CNN, hidden linear layer, output linear layer.
                Only read when ``load_params`` is True. The list is never modified,
                so the shared ``[]`` default is harmless.
            load_params (bool): whether the layers are initialised from ``parameters``
        """
        super().__init__()

        def as_parameter(values):
            # every pretrained array is stored as a float32 parameter
            return torch.nn.Parameter(torch.tensor(values).float())

        def conv_branch(kernel_height, index):
            # one CNN branch: convolution -> ReLU -> max pooling
            conv = nn.Conv2d(in_channels=1, out_channels=16,
                             kernel_size=(kernel_height, 13), stride=(3, 1))
            if load_params:
                pretrained = parameters[index]
                # reorder the axes into Conv2d's (out, in, height, width) layout
                conv.weight = as_parameter(np.transpose(pretrained[0], (3, 2, 0, 1)))
                conv.bias = as_parameter(pretrained[1])
            return nn.Sequential(conv, nn.ReLU(), nn.MaxPool2d((23, 1)))

        def linear_layer(in_features, out_features, index):
            layer = nn.Linear(in_features=in_features, out_features=out_features)
            if load_params:
                pretrained = parameters[index]
                # the stored matrix is transposed to match nn.Linear's (out, in) weight
                layer.weight = as_parameter(pretrained[0].T)
                layer.bias = as_parameter(pretrained[1])
            return layer

        # Layers are created in a fixed order so random initialisation stays reproducible.
        self.cnn1 = conv_branch(11, 0)
        self.cnn2 = conv_branch(12, 1)
        self.linear1 = linear_layer(64, 100, 2)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self.linear2 = linear_layer(100, 3, 3)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, feat_inputs):
        """
        PyTorch forward method used for training and prediction.

        Args:
            feat_inputs: batch of feature matrices, either a tensor or anything
                ``np.asarray`` can stack into a 3-dimensional array (a NumPy array,
                a list of equally shaped arrays, ...). A channel axis is inserted
                at position 1 before the convolutions.
        Returns:
            out (torch.tensor): softmax output with one row of 3 values per batch entry
        """
        if isinstance(feat_inputs, torch.Tensor):
            # torch.tensor(tensor) emits a warning; clone().detach() is the
            # documented way to take the same detached copy
            feat_inputs = feat_inputs.clone().detach().float()
        else:
            # Stack in NumPy first: torch.tensor on a list of ndarrays converts
            # element by element, which is very slow and emits a warning.
            feat_inputs = torch.tensor(np.asarray(feat_inputs, dtype=np.float32))
        feat_inputs = feat_inputs.unsqueeze(1)
        cnn_1 = self.cnn1(feat_inputs).flatten(1)
        cnn_2 = self.cnn2(feat_inputs).flatten(1)
        out = torch.cat((cnn_1, cnn_2), 1)
        out = self.linear1(out)
        out = self.relu(out)
        out = self.dropout(out)
        out = self.linear2(out)
        out = self.softmax(out)
        return out
__init__(self, parameters=[], load_params=False)
special
¶
Defines the elements/layers of the neural network as well as loads the pretrained parameters
The model is constituted by two parallel CNNs followed by a concatenation, a FFN and a softmax layer.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
parameters |
list |
list of pre-trained parameters to be used for prediction |
[] |
load_params |
bool |
Bool to signal if params should be loaded |
False |
Source code in adviser/services/backchannel/PytorchAcousticBackchanneler.py
def __init__(self, parameters:list=[], load_params:bool=False):
"""
Defines the elements/layers of the neural network as well as loads the pretrained parameters
The model is constituted by two parallel CNNs followed by a concatenation, a FFN and a softmax layer.
Args:
parameters (list): list of pre-trained parameters to be used for prediction
load_params (bool): Bool to signal if params should be loaded
"""
super(PytorchAcousticBackchanneler, self).__init__()
# First CNN
cnn = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(11, 13), stride=(3,1))
if load_params:
weights = np.transpose(parameters[0][0], (3, 2, 0, 1))
cnn.weight = torch.nn.Parameter(torch.tensor(weights).float())
cnn.bias = torch.nn.Parameter(torch.tensor(parameters[0][1]).float())
self.cnn1 = nn.Sequential(
cnn,
nn.ReLU(),
nn.MaxPool2d((23, 1))
)
# Second CNN
cnn = nn.Conv2d(in_channels=1, out_channels=16, kernel_size = (12, 13), stride=(3,1))
if load_params:
weights = np.transpose(parameters[1][0], (3,2,0,1))
cnn.weight = torch.nn.Parameter(torch.tensor(weights).float())
cnn.bias = torch.nn.Parameter(torch.tensor(parameters[1][1]).float())
self.cnn2 = nn.Sequential(
cnn,
nn.ReLU(),
nn.MaxPool2d((23, 1))
)
# Linear layer
self.linear1 = nn.Linear(in_features=64, out_features=100)
if load_params:
self.linear1.weight = torch.nn.Parameter(torch.tensor(parameters[2][0].T).float())
self.linear1.bias = torch.nn.Parameter(torch.tensor(parameters[2][1]).float())
self.relu = nn.ReLU()
self.dropout = nn.Dropout(0.5)
# Softmax
self.linear2 = nn.Linear(in_features=100, out_features=3)
if load_params:
self.linear2.weight = torch.nn.Parameter(torch.tensor(parameters[3][0].T).float())
self.linear2.bias = torch.nn.Parameter(torch.tensor(parameters[3][1]).float())
self.softmax = nn.Softmax(dim=1)
forward(self, feat_inputs)
¶
PyTorch forward method used for training and prediction. It defines the interaction between layers.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
feat_inputs |
numpy array |
It contains the network's input. |
required |
Returns:
Type | Description |
---|---|
out (torch.tensor) |
Network's output |
Source code in adviser/services/backchannel/PytorchAcousticBackchanneler.py
def forward(self, feat_inputs):
"""
PyTorch forward method used for training and prediction. It defines the interaction between layers.
Args:
feat_inputs (numpy array): It contains the network's input.
Returns:
out (torch.tensor): Network's output
"""
feat_inputs = torch.tensor(feat_inputs).float()
feat_inputs = feat_inputs.unsqueeze(1)
cnn_1 = self.cnn1(feat_inputs)
cnn_1 = cnn_1.flatten(1)
cnn_2 = self.cnn2(feat_inputs).flatten(1)
out = torch.cat((cnn_1, cnn_2), 1)
out = self.linear1(out)
out = self.relu(out)
out = self.dropout(out)
out = self.linear2(out)
out = self.softmax(out)
return out
acoustic_backchanneller
¶
gpu
¶
AcousticBackchanneller (Service)
¶
AcousticBackchanneller predicts a backchannel given the last user utterance. The model can predict: No backchannel (0), Assessment (1), Continuer (2). The backchannel realization is added in the NLG module.
Source code in adviser/services/backchannel/acoustic_backchanneller.py
class AcousticBackchanneller(Service):
    """AcousticBackchanneller predicts a backchannel given the last user utterance.

    The model can predict: No backchannel (0), Assessment (1), Continuer (2).
    The backchannel realization is added in the NLG module.
    """

    def __init__(self):
        """Sets up the model path and loads the pretrained model."""
        Service.__init__(self)
        self.speech_in_dir = os.path.dirname(os.path.abspath(__file__)) + '/'
        # relative path: resolved against the current working directory
        self.trained_model_path = os.path.join(
            'resources', 'models', 'backchannel', 'pytorch_acoustic_backchanneller.pt')
        self.load_model()

    def load_model(self):
        """
        The PyTorch Backchannel model is instantiated and the pretrained parameters are loaded.

        The parameters are mapped to the CPU while loading, so a checkpoint saved
        on a GPU also loads on machines without one (predictions are converted
        to NumPy, which requires CPU tensors anyway).
        """
        self.model = PytorchAcousticBackchanneler()
        self.model.load_state_dict(
            torch.load(self.trained_model_path, map_location=torch.device('cpu')))
        self.model.eval()

    def split_input_data(self, mfcc_features):
        """
        Preprocessing and segmentation of MFCC features of the user's speech.
        Segmentation is done every 150 frames without overlapping. Segments are
        aligned to the end of the input, so only the first one can be incomplete;
        it is padded with zeros at the front.

        Args:
            mfcc_features (numpy.array): two-dimensional MFCC features of the user's
                speech, one row per frame
        Returns:
            new_data (list): segmented mfcc features, each with 150 rows
        """
        input_height = 150  # this stands for 150ms
        input_length = mfcc_features.shape[0]
        zero_shape = list(mfcc_features.shape)
        zero_shape[0] = input_height
        # Exclusive end offset of every segment, in ascending order. Starting at
        # input_length (not input_length - 1) keeps the very last frame.
        segment_ends = sorted(range(input_length, 0, -input_height))
        new_data = []
        for end in segment_ends:
            if end < input_height:
                # incomplete leading segment: zeros in front, data at the end
                zero_data = np.zeros(zero_shape)
                zero_data[-end:, :] = mfcc_features[:end, :]
                new_data.append(zero_data)
            else:
                new_data.append(mfcc_features[end - input_height:end, :])
        return new_data

    @PublishSubscribe(sub_topics=['mfcc'],
                      pub_topics=["predicted_BC"])
    def backchannel_prediction(self, mfcc: torch.Tensor):
        """
        Service that receives the MFCC features from the user's speech.
        It preprocesses and normalizes them and makes the BC prediction.

        Args:
            mfcc (torch.tensor): MFCC features
        Returns:
            (dict): a dictionary with the key "predicted_BC" and the value of the BC type
                as a plain int (0: no backchannel, 1: assessment, 2: continuer)
        """
        mfcc_features = mfcc.numpy()
        scaler = preprocessing.StandardScaler()
        mfcc_features = scaler.fit_transform(mfcc_features)
        input_splits = self.split_input_data(mfcc_features)
        prediction = self.model(input_splits).detach().numpy().argmax(axis=1)

        # A unanimous prediction is returned as is. Otherwise a backchannel wins
        # over "no backchannel", and if both backchannel types occur the more
        # frequent one is chosen (a tie goes to the continuer).
        if len(set(prediction)) == 1:
            return {'predicted_BC': int(prediction[0])}
        ones = int((prediction == 1).sum())
        twos = int((prediction == 2).sum())
        if ones and twos:
            return {'predicted_BC': 1 if ones > twos else 2}
        return {'predicted_BC': 1 if ones else 2}
__init__(self)
special
¶
Source code in adviser/services/backchannel/acoustic_backchanneller.py
backchannel_prediction(self, *args, **kwargs)
¶
Source code in adviser/services/backchannel/acoustic_backchanneller.py
# Wrapper around `func`: calls it and, when this instance has a publish socket
# registered for it, sends every topic of `pub_topics` found in the returned dict
# through `_send_msg`. `func`, `pub_topics` and `_send_msg` come from an enclosing
# scope that is not part of this listing.
# NOTE(review): the listing has lost its indentation, so the nesting of the
# statements below cannot be read off the text - check the original module
# before changing anything here.
def delegate(self, *args, **kwargs):
func_inst = getattr(self, func.__name__)
callargs = list(args)
if self in callargs: # remove self when in *args, because already known to function
callargs.remove(self)
result = func(self, *callargs, **kwargs)
if result:
# Returned keys may have the form "topic/domain": the part before the first "/"
# is used as topic, the part after it as domain ("" when the key has no "/").
# fix! (user could have multiple "/" characters in topic - only use last one )
domains = {res.split("/")[0]: res.split("/")[1] if "/" in res else "" for res in result}
result = {key.split("/")[0]: result[key] for key in result}
if func_inst not in self._publish_sockets:
# not a publisher, just normal function
return result
socket = self._publish_sockets[func_inst]
domain = self._domain_name
if socket and result:
# publish messages
for topic in pub_topics:
# for topic in result: # NOTE publish any returned value in dict with it's key as topic
if topic in result:
# NOTE(review): `domain` is reassigned here, so a domain taken from one topic
# appears to carry over to the topics handled after it - confirm this is intended.
domain = domain if domain else domains[topic]
topic_domain_str = f"{topic}/{domain}" if domain else topic
# an entry in _pub_topic_domains replaces the domain chosen above
# (an empty entry publishes the bare topic)
if topic in self._pub_topic_domains:
topic_domain_str = f"{topic}/{self._pub_topic_domains[topic]}" if self._pub_topic_domains[topic] else topic
_send_msg(socket, topic_domain_str, result[topic])
if self.debug_logger:
self.debug_logger.info(
f"- (DS): sent message from {func} to topic {topic_domain_str}:\n {result[topic]}")
return result
load_model(self)
¶
The PyTorch Backchannel model is instantiated and the pretrained parameters are loaded.
Source code in adviser/services/backchannel/acoustic_backchanneller.py
split_input_data(self, mfcc_features)
¶
Preprocess and segmentation of MFCC features of the user's speech. Segmentation is done every 150ms without overlapping.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
mfcc_features |
numpy.array |
MFCC features of the user's speech |
required |
Returns:
Type | Description |
---|---|
new_data (list) |
segmented mfcc features |
Source code in adviser/services/backchannel/acoustic_backchanneller.py
def split_input_data(self, mfcc_features):
    """
    Preprocessing and segmentation of MFCC features of the user's speech.
    Segmentation is done every 150 frames without overlapping. Segments are
    aligned to the end of the input, so only the first one can be incomplete;
    it is padded with zeros at the front.

    Args:
        mfcc_features (numpy.array): two-dimensional MFCC features of the user's
            speech, one row per frame
    Returns:
        new_data (list): segmented mfcc features, each with 150 rows
    """
    input_height = 150  # this stands for 150ms
    input_length = mfcc_features.shape[0]
    zero_shape = list(mfcc_features.shape)
    zero_shape[0] = input_height
    # Exclusive end offset of every segment, in ascending order. Starting at
    # input_length (not input_length - 1) keeps the very last frame.
    segment_ends = sorted(range(input_length, 0, -input_height))
    new_data = []
    for end in segment_ends:
        if end < input_height:
            # incomplete leading segment: zeros in front, data at the end
            zero_data = np.zeros(zero_shape)
            zero_data[-end:, :] = mfcc_features[:end, :]
            new_data.append(zero_data)
        else:
            new_data.append(mfcc_features[end - input_height:end, :])
    return new_data
bst
special
¶
bst
¶
HandcraftedBST (Service)
¶
A rule-based approach to belief state tracking.
Source code in adviser/services/bst/bst.py
class HandcraftedBST(Service):
    """
    Belief state tracker driven by hand-written rules.
    """

    def __init__(self, domain=None, logger=None):
        Service.__init__(self, domain=domain)
        self.logger = logger
        self.bs = BeliefState(domain)

    @PublishSubscribe(sub_topics=["user_acts"], pub_topics=["beliefstate"])
    def update_bst(self, user_acts: List[UserAct] = None) \
            -> dict(beliefstate=BeliefState):
        """
        Starts a new turn in the belief state and, if there are user acts,
        folds them into it.

        Args:
            user_acts (list): UserAct objects mapped from the user's last utterance
        Returns:
            (dict): a dictionary with the key "beliefstate" and the updated
                    BeliefState object as value
        """
        # the previous turn is kept in the belief state's history
        self.bs.start_new_turn()
        if not user_acts:
            return {'beliefstate': self.bs}

        self._reset_informs(user_acts)
        self._reset_requests()
        self.bs["user_acts"] = self._get_all_usr_action_types(user_acts)
        self._handle_user_acts(user_acts)

        matches, can_discriminate = self.bs.get_num_dbmatches()
        self.bs["num_matches"] = matches
        self.bs["discriminable"] = can_discriminate
        return {'beliefstate': self.bs}

    def dialog_start(self):
        """
        Replaces the belief state with a fresh one for the service's domain,
        so tracking starts from scratch in the new dialog. Returns nothing.
        """
        self.bs = BeliefState(self.domain)

    def _reset_informs(self, acts: List[UserAct]):
        """
        Drops the stored informs of every slot the user informs about again
        in this turn.
        """
        informed_slots = {act.slot for act in acts if act.type == UserActionType.Inform}
        # iterate over a copy because entries are deleted on the way
        for known_slot in list(self.bs['informs']):
            if known_slot in informed_slots:
                del self.bs['informs'][known_slot]

    def _reset_requests(self):
        """
        Forgets the requests of the previous turn.
        """
        self.bs['requests'] = {}

    def _get_all_usr_action_types(self, user_acts: List[UserAct]) -> Set[UserActionType]:
        """
        Collects the distinct action types occurring in user_acts.

        Args:
            user_acts (List[UserAct]): list of UserAct objects
        Returns:
            set of UserActionType objects
        """
        return {act.type for act in user_acts}

    def _handle_user_acts(self, user_acts: List[UserAct]):
        """
        Applies the user acts of this turn to the belief state.

        Args:
            user_acts (list[UserAct]): the user acts to apply
        """
        act_types = self.bs["user_acts"]
        if self.domain.get_primary_key() in self.bs['informs'] \
                and UserActionType.Inform in act_types:
            # new information from the user invalidates the current offer
            del self.bs['informs'][self.domain.get_primary_key()]
        elif UserActionType.SelectDomain in act_types:
            # switching domains is treated as starting over; resuming an
            # earlier dialog is not supported
            self.bs["informs"] = {}
            self.bs["requests"] = {}

        for act in user_acts:
            kind = act.type
            if kind == UserActionType.Request:
                self.bs['requests'][act.slot] = act.score
            elif kind == UserActionType.Inform:
                # remember the value together with its score
                if act.slot not in self.bs["informs"]:
                    self.bs['informs'][act.slot] = {act.value: act.score}
                else:
                    self.bs['informs'][act.slot][act.value] = act.score
            elif kind == UserActionType.NegativeInform:
                # the rejected value is removed if it was stored before
                if act.slot in self.bs['informs'] and act.value in self.bs['informs'][act.slot]:
                    del self.bs['informs'][act.slot][act.value]
            elif kind == UserActionType.RequestAlternatives:
                # the user is no longer interested in the offered item
                if self.domain.get_primary_key() in self.bs['informs']:
                    del self.bs['informs'][self.domain.get_primary_key()]
__init__(self, domain=None, logger=None)
special
¶
dialog_start(self)
¶
Resets the belief state so it is ready for a new dialog
Returns:
Type | Description |
---|---|
(dict) |
a dictionary with a single entry where the key is 'beliefstate' and the value is a new BeliefState object |
Source code in adviser/services/bst/bst.py
update_bst(self, *args, **kwargs)
¶
Source code in adviser/services/bst/bst.py
# Wrapper around `func`: calls it and, when this instance has a publish socket
# registered for it, sends every topic of `pub_topics` found in the returned dict
# through `_send_msg`. `func`, `pub_topics` and `_send_msg` come from an enclosing
# scope that is not part of this listing.
# NOTE(review): the listing has lost its indentation, so the nesting of the
# statements below cannot be read off the text - check the original module
# before changing anything here.
def delegate(self, *args, **kwargs):
func_inst = getattr(self, func.__name__)
callargs = list(args)
if self in callargs: # remove self when in *args, because already known to function
callargs.remove(self)
result = func(self, *callargs, **kwargs)
if result:
# Returned keys may have the form "topic/domain": the part before the first "/"
# is used as topic, the part after it as domain ("" when the key has no "/").
# fix! (user could have multiple "/" characters in topic - only use last one )
domains = {res.split("/")[0]: res.split("/")[1] if "/" in res else "" for res in result}
result = {key.split("/")[0]: result[key] for key in result}
if func_inst not in self._publish_sockets:
# not a publisher, just normal function
return result
socket = self._publish_sockets[func_inst]
domain = self._domain_name
if socket and result:
# publish messages
for topic in pub_topics:
# for topic in result: # NOTE publish any returned value in dict with it's key as topic
if topic in result:
# NOTE(review): `domain` is reassigned here, so a domain taken from one topic
# appears to carry over to the topics handled after it - confirm this is intended.
domain = domain if domain else domains[topic]
topic_domain_str = f"{topic}/{domain}" if domain else topic
# an entry in _pub_topic_domains replaces the domain chosen above
# (an empty entry publishes the bare topic)
if topic in self._pub_topic_domains:
topic_domain_str = f"{topic}/{self._pub_topic_domains[topic]}" if self._pub_topic_domains[topic] else topic
_send_msg(socket, topic_domain_str, result[topic])
if self.debug_logger:
self.debug_logger.info(
f"- (DS): sent message from {func} to topic {topic_domain_str}:\n {result[topic]}")
return result
domain_tracker
special
¶
domain_tracker
¶
This module provides ADVISER services for tracking the current domain
DomainTracker (Service)
¶
Responsible for selecting which domain should be active at a given time. The current implementation uses keywords to switch domains.
Source code in adviser/services/domain_tracker/domain_tracker.py
class DomainTracker(Service):
    """
    Responsible for selecting which domain should be active at a given time.
    The current implementation uses keywords to switch domains.
    """

    def __init__(self, domains: List[Domain], greet_on_first_turn: bool = False):
        """
        Args:
            domains (List[Domain]): the domains the system can talk about
            greet_on_first_turn (bool): if True, the first turn of a dialog is answered
                with a greeting instead of routing the utterance
        """
        Service.__init__(self, domain="")
        self.domains = domains
        self.current_domain = None
        self.greet_on_first_turn = greet_on_first_turn
        # Initialised here as well, so select_domain also works when it is
        # called before dialog_start.
        self.turn = 0

    def dialog_start(self):
        """
        Resets the domain tracker for the start of a new dialog
        """
        self.turn = 0
        self.current_domain = None

    def _greeting(self) -> str:
        """Builds the message that lists the available domains."""
        return "Hello, please let me know how I can help you, I can discuss " + \
               f"the following domains: {self.domains_to_str()}."

    @PublishSubscribe(sub_topics=["gen_user_utterance"], pub_topics=["user_utterance", "sys_utterance"])
    def select_domain(self, gen_user_utterance: str = None) -> dict(user_utterance=str):
        """
        Determines which domain should currently be active. In general, if a keyword is mentioned, the domain
        will change, otherwise it is assumed that the previous domain is still active.

        Args:
            gen_user_utterance (str): the user utterance, before a domain has been determined;
                None is treated like an empty utterance
        Returns:
            (dict): either a dictionary with "user_utterance/<domain name>" as key and the
                    stripped, lower-cased utterance as value, so the message can be properly
                    routed, or a dictionary with the key "sys_utterance" and a system message
                    (greeting or goodbye) as value
        """
        self.turn += 1
        if self.turn == 1 and self.greet_on_first_turn:
            return {'sys_utterance': self._greeting()}

        # if there is only a single domain, simply route the message forward
        if len(self.domains) == 1:
            self.current_domain = self.domains[0]

        # normalise the utterance; a missing one becomes "" so the keyword
        # matching below cannot fail on None
        user_utterance = gen_user_utterance.strip().lower() if gen_user_utterance else ""

        # perform keyword matching to see if any domains are explicitly made active
        active_domains = [d for d in self.domains if d.get_keyword() in user_utterance]

        # Even if no domain has been specified, we should be able to exit
        if "bye" in user_utterance and not self.current_domain:
            return {"sys_utterance": "Thank you, goodbye."}
        # if there are active domains, use the first one
        if active_domains:
            self.current_domain = active_domains[0]
            return {f"user_utterance/{active_domains[0].get_domain_name()}": user_utterance}
        # if no domain is explicitly mentioned, assume the last one is still active
        if self.current_domain:
            return {f"user_utterance/{self.current_domain.get_domain_name()}": user_utterance}
        # Otherwise ask the user what domain they want
        return {"sys_utterance": self._greeting()}

    def domains_to_str(self):
        """
        Method to create the greeting on the first turn, grammatically joins the names of possible domains into
        a string

        Returns:
            (str): String representing a list of all domain names the system can talk about
        """
        names = [d.get_display_name() for d in self.domains]
        if len(names) == 1:
            return names[0]
        if len(names) == 2:
            return " and ".join(names)
        return ", ".join(names[:-1]) + f", and {names[-1]}"
__init__(self, domains, greet_on_first_turn=False)
special
¶
dialog_start(self)
¶
domains_to_str(self)
¶
Method to create the greeting on the first turn, grammatically joins the names of possible domains into a string
Returns:
Type | Description |
---|---|
(str) |
String representing a list of all domain names the system can talk about |
Source code in adviser/services/domain_tracker/domain_tracker.py
def domains_to_str(self):
    """
    Joins the display names of all domains into a grammatical enumeration,
    as used in the greeting.

    Returns:
        (str): String representing a list of all domain names the system can talk about
    """
    count = len(self.domains)
    if count == 1:
        return self.domains[0].get_display_name()
    if count == 2:
        return " and ".join(d.get_display_name() for d in self.domains)
    # three or more: comma separated, with ", and" before the last name
    leading = [d.get_display_name() for d in self.domains][:-1]
    return ", ".join(leading) + f", and {self.domains[-1].get_display_name()}"
select_domain(self, *args, **kwargs)
¶
Source code in adviser/services/domain_tracker/domain_tracker.py
# Wrapper around `func`: calls it and, when this instance has a publish socket
# registered for it, sends every topic of `pub_topics` found in the returned dict
# through `_send_msg`. `func`, `pub_topics` and `_send_msg` come from an enclosing
# scope that is not part of this listing.
# NOTE(review): the listing has lost its indentation, so the nesting of the
# statements below cannot be read off the text - check the original module
# before changing anything here.
def delegate(self, *args, **kwargs):
func_inst = getattr(self, func.__name__)
callargs = list(args)
if self in callargs: # remove self when in *args, because already known to function
callargs.remove(self)
result = func(self, *callargs, **kwargs)
if result:
# Returned keys may have the form "topic/domain": the part before the first "/"
# is used as topic, the part after it as domain ("" when the key has no "/").
# fix! (user could have multiple "/" characters in topic - only use last one )
domains = {res.split("/")[0]: res.split("/")[1] if "/" in res else "" for res in result}
result = {key.split("/")[0]: result[key] for key in result}
if func_inst not in self._publish_sockets:
# not a publisher, just normal function
return result
socket = self._publish_sockets[func_inst]
domain = self._domain_name
if socket and result:
# publish messages
for topic in pub_topics:
# for topic in result: # NOTE publish any returned value in dict with it's key as topic
if topic in result:
# NOTE(review): `domain` is reassigned here, so a domain taken from one topic
# appears to carry over to the topics handled after it - confirm this is intended.
domain = domain if domain else domains[topic]
topic_domain_str = f"{topic}/{domain}" if domain else topic
# an entry in _pub_topic_domains replaces the domain chosen above
# (an empty entry publishes the bare topic)
if topic in self._pub_topic_domains:
topic_domain_str = f"{topic}/{self._pub_topic_domains[topic]}" if self._pub_topic_domains[topic] else topic
_send_msg(socket, topic_domain_str, result[topic])
if self.debug_logger:
self.debug_logger.info(
f"- (DS): sent message from {func} to topic {topic_domain_str}:\n {result[topic]}")
return result
emotion
special
¶
EmotionRecognition
¶
Emotion recognition module.
EmotionRecognition (Service)
¶
Emotion recognition module.
This module receives acoustic features, loads pretrained models and outputs predictions of emotional states. It can easily be extended/adapted to use different models and facial features in addition.
Source code in adviser/services/emotion/EmotionRecognition.py
class EmotionRecognition(Service):
    """Emotion recognition module.

    This module receives acoustic features, loads pretrained models and outputs
    predictions of emotional states. It can easily be extended/adapted to use
    different models and facial features in addition.
    """

    def __init__(self):
        """ Emotion recognition module.

        On initialization all necessary models are loaded: for each emotion
        representation (category, arousal, valence) the pickled argument
        dictionary and the matching model parameters are read from
        ``resources/models/emotion``.
        """
        Service.__init__(self)
        self.emotion_dir = os.path.dirname(os.path.abspath(__file__))
        self.model_path = os.path.abspath(
            os.path.join(
                self.emotion_dir, "..", "..", "resources", "models", "emotion"
            )
        )

        def load_args(emo_representation):
            # NOTE: unpickling can execute arbitrary code - these files must
            # only ever come from a trusted source.
            args_path = os.path.join(self.model_path, f'{emo_representation}_args.pkl')
            # the context manager closes the file even if unpickling fails
            with open(args_path, 'rb') as args_file:
                return pickle.load(args_file)

        def load_model(emo_representation, arg_dict):
            ARGS = arg_dict['args']
            model = cnn(
                kernel_size=(ARGS.height, arg_dict['D_in']),
                D_out=arg_dict['D_out'],
                args=ARGS
            )
            model.load_state_dict(
                torch.load(
                    os.path.join(self.model_path,
                                 f'{emo_representation}_model_params.pt'),
                    map_location=torch.device('cpu')
                )
            )
            model.eval()
            return model

        self.emo_representations = ['category', 'arousal', 'valence']
        self.models = {}
        self.args = {}
        for emo_representation in self.emo_representations:
            self.args[emo_representation] = load_args(emo_representation)
            self.models[emo_representation] = load_model(
                emo_representation,
                self.args[emo_representation]
            )
        # class index -> label, per emotion representation
        self.arousal_mapping = {0: 'low', 1: 'medium', 2: 'high'}
        self.valence_mapping = {0: 'negative', 1: 'neutral', 2: 'positive'}
        self.category_mapping = {
            0: EmotionType.Angry,
            1: EmotionType.Happy,
            2: EmotionType.Neutral,
            3: EmotionType.Sad
        }

    @PublishSubscribe(sub_topics=["fbank"], pub_topics=["emotion"])
    def predict_from_audio(self, fbank):
        """Emotion prediction from acoustic features.

        Args:
            fbank (torch.Tensor): feature array, shape (sequence, num_mel_bins)

        Returns:
            dict: nested dictionary containing all results, main key: 'emotion'.
                The inner dictionary holds 'arousal', 'valence', 'category' and the
                rounded category probabilities under 'cateogry_probabilities'
                (the misspelled key is kept because consumers look it up by name).
        """
        def normalize_and_pad_features(features: torch.Tensor, seq_len, mean: torch.Tensor, std: torch.Tensor):
            # normalize
            features = (features - mean) / std
            # number of zero rows needed to reach seq_len (none if already long enough)
            missing = seq_len - features.size(0) if seq_len > features.size(0) else 0
            # cut to seq_len rows, then pad with zeros along the time dimension
            return torch.cat(
                [features[:seq_len], features.new_zeros(missing, features.size(1))],
                dim=0
            )

        predictions = {}
        for emo_representation in self.emo_representations:
            seq_len = self.args[emo_representation]['args'].seq_length
            mean = self.args[emo_representation]['norm_mean']
            std = self.args[emo_representation]['norm_std']
            # feature normalization and padding has to be done for each
            # emotion representation individually because the means and
            # standard (deviations) (and sequence length) can be different
            features = normalize_and_pad_features(fbank, seq_len, torch.from_numpy(mean), torch.from_numpy(std))
            predictions[emo_representation] = softmax(
                self.models[emo_representation](features.unsqueeze(1)), dim=1
            ).detach().numpy()

        arousal_level = self.arousal_mapping[np.argmax(predictions['arousal'])]
        valence_level = self.valence_mapping[np.argmax(predictions['valence'])]
        category_label = self.category_mapping[np.argmax(predictions['category'])]

        return {'emotion': {'arousal': arousal_level,
                            'valence': valence_level,
                            'category': category_label,
                            'cateogry_probabilities':
                                np.around(predictions['category'], 2).reshape(-1)}}
__init__(self)
special
¶
Emotion recognition module.
On initialization all necessary models are loaded.
Source code in adviser/services/emotion/EmotionRecognition.py
def __init__(self):
    """ Emotion recognition module.

    On initialization all necessary models are loaded: for each emotion
    representation (category, arousal, valence) the pickled argument
    dictionary and the matching model parameters are read from
    ``resources/models/emotion``.
    """
    Service.__init__(self)
    self.emotion_dir = os.path.dirname(os.path.abspath(__file__))
    self.model_path = os.path.abspath(
        os.path.join(
            self.emotion_dir, "..", "..", "resources", "models", "emotion"
        )
    )

    def load_args(emo_representation):
        # NOTE: unpickling can execute arbitrary code - these files must
        # only ever come from a trusted source.
        args_path = os.path.join(self.model_path, f'{emo_representation}_args.pkl')
        # the context manager closes the file even if unpickling fails
        with open(args_path, 'rb') as args_file:
            return pickle.load(args_file)

    def load_model(emo_representation, arg_dict):
        ARGS = arg_dict['args']
        model = cnn(
            kernel_size=(ARGS.height, arg_dict['D_in']),
            D_out=arg_dict['D_out'],
            args=ARGS
        )
        model.load_state_dict(
            torch.load(
                os.path.join(self.model_path,
                             f'{emo_representation}_model_params.pt'),
                map_location=torch.device('cpu')
            )
        )
        model.eval()
        return model

    self.emo_representations = ['category', 'arousal', 'valence']
    self.models = {}
    self.args = {}
    for emo_representation in self.emo_representations:
        self.args[emo_representation] = load_args(emo_representation)
        self.models[emo_representation] = load_model(
            emo_representation,
            self.args[emo_representation]
        )
    # class index -> label, per emotion representation
    self.arousal_mapping = {0: 'low', 1: 'medium', 2: 'high'}
    self.valence_mapping = {0: 'negative', 1: 'neutral', 2: 'positive'}
    self.category_mapping = {
        0: EmotionType.Angry,
        1: EmotionType.Happy,
        2: EmotionType.Neutral,
        3: EmotionType.Sad
    }
predict_from_audio(self, *args, **kwargs)
¶
Source code in adviser/services/emotion/EmotionRecognition.py
# Wrapper around `func`: calls it and, when this instance has a publish socket
# registered for it, sends every topic of `pub_topics` found in the returned dict
# through `_send_msg`. `func`, `pub_topics` and `_send_msg` come from an enclosing
# scope that is not part of this listing.
# NOTE(review): the listing has lost its indentation, so the nesting of the
# statements below cannot be read off the text - check the original module
# before changing anything here.
def delegate(self, *args, **kwargs):
func_inst = getattr(self, func.__name__)
callargs = list(args)
if self in callargs: # remove self when in *args, because already known to function
callargs.remove(self)
result = func(self, *callargs, **kwargs)
if result:
# Returned keys may have the form "topic/domain": the part before the first "/"
# is used as topic, the part after it as domain ("" when the key has no "/").
# fix! (user could have multiple "/" characters in topic - only use last one )
domains = {res.split("/")[0]: res.split("/")[1] if "/" in res else "" for res in result}
result = {key.split("/")[0]: result[key] for key in result}
if func_inst not in self._publish_sockets:
# not a publisher, just normal function
return result
socket = self._publish_sockets[func_inst]
domain = self._domain_name
if socket and result:
# publish messages
for topic in pub_topics:
# for topic in result: # NOTE publish any returned value in dict with it's key as topic
if topic in result:
# NOTE(review): `domain` is reassigned here, so a domain taken from one topic
# appears to carry over to the topics handled after it - confirm this is intended.
domain = domain if domain else domains[topic]
topic_domain_str = f"{topic}/{domain}" if domain else topic
# an entry in _pub_topic_domains replaces the domain chosen above
# (an empty entry publishes the bare topic)
if topic in self._pub_topic_domains:
topic_domain_str = f"{topic}/{self._pub_topic_domains[topic]}" if self._pub_topic_domains[topic] else topic
_send_msg(socket, topic_domain_str, result[topic])
if self.debug_logger:
self.debug_logger.info(
f"- (DS): sent message from {func} to topic {topic_domain_str}:\n {result[topic]}")
return result
engagement
special
¶
engagement_tracker
¶
EngagementTracker (Service)
¶
Start feature extraction with OpenFace. Requires OpenFace to be installed - instructions can be found in tool/openface.txt
Source code in adviser/services/engagement/engagement_tracker.py
class EngagementTracker(Service):
"""
Start feature extraction with OpenFace.
Requires OpenFace to be installed - instructions can be found in tool/openface.txt
"""
def __init__(self, domain="", camera_id: int = 0, openface_port: int = 6004, delay: int = 2, identifier=None):
    """
    Binds the socket used to talk to OpenFace and starts the OpenFace process.

    Args:
        domain: accepted for interface compatibility; the service is always
            registered with an empty domain
        camera_id: index of the camera you want to use (if you only have one camera: 0)
        openface_port: local TCP port shared by this service and the OpenFace process
        delay: stored as ``self.threshold``
        identifier: passed on to ``Service.__init__``
    """
    Service.__init__(self, domain="", identifier=identifier)
    self.camera_id = camera_id
    self.openface_port = openface_port
    self.openface_running = False
    # NOTE(review): described as a number of seconds (one second = 15 frames),
    # but publish_gaze_directions uses the value directly as a frame count -
    # confirm which unit is intended.
    self.threshold = delay
    ctx = Context.instance()
    self.openface_endpoint = ctx.socket(zmq.PAIR)
    self.openface_endpoint.bind(f"tcp://127.0.0.1:{self.openface_port}")
    # The command is passed as an argument list, so a root directory containing
    # spaces stays intact, and OpenFace gets the same port the socket is bound to.
    openface_binary = os.path.join(get_root_dir(), 'tools/OpenFace/build/bin/FaceLandmarkVidZMQ')
    start_extraction = [openface_binary,
                        "-device", str(self.camera_id),
                        "-port", str(self.openface_port)]
    self.p_openface = subprocess.Popen(start_extraction, stdout=subprocess.PIPE)  # start OpenFace
    self.extracting = False
    self.extractor_thread = None
def dialog_start(self):
    """
    Asks OpenFace to start publishing, blocks until it confirms, and then
    starts the thread that runs `publish_gaze_directions`.
    """
    self.openface_endpoint.send("OPENFACE_START".encode("ascii"))
    self.extracting = False
    # block until OpenFace confirms that it has started
    while not self.extracting:
        reply = self.openface_endpoint.recv().decode("utf-8")
        if reply == "OPENFACE_STARTED":
            print("START EXTRACTION")
            self.extracting = True
    worker = Thread(target=self.publish_gaze_directions)
    self.extractor_thread = worker
    worker.start()
@PublishSubscribe(pub_topics=["engagement", "gaze_direction"])
def yield_gaze_direction(self, engagement: EngagementType, gaze_direction: Tuple[float, float]):
    """
    Helper that hands the given engagement features back as a dictionary keyed
    by the published topics. Intended to be called from a continuously running loop.

    Returns:
        engagement (EngagementType): high / low
        gaze_direction (float, float): tuple of gaze-x-angle and gaze-y-angle
    """
    features = {"engagement": engagement}
    features["gaze_direction"] = gaze_direction
    return features
def publish_gaze_directions(self):
"""
Meant to be used in a thread.
Runs an inifinte loop polling features from OpenFace library, parsing them and extracting engagement features.
Calls `yield_gaze_direction` to publish the polled and processed engagement features.
"""
x_coordinates=[]
y_coordinates=[]
norm = 0.0 # center point of screen; should be close(r) to 0
looking = True
while self.extracting:
req = self.openface_endpoint.send(bytes(f"OPENFACE_PULL", encoding="ascii"))
msg = self.openface_endpoint.recv()
try:
msg = msg.decode("utf-8")
if msg == "OPENFACE_ENDED":
self.extracting = False
msg_data = json.loads(msg)
gaze_x = msg_data["gaze"]["angle"]["x"]
gaze_y = msg_data["gaze"]["angle"]["y"]
gaze_x = sqrt(gaze_x**2) # gaze_angle_x (left-right movement), square + root is done to yield only positive values
gaze_y = sqrt(gaze_y**2) # gaze_angle_y (up-down movement)
x_coordinates.append(gaze_x)
y_coordinates.append(gaze_y)
current = (len(x_coordinates))-1
if current > self.threshold:
previous_x = mean(x_coordinates[current-(self.threshold+1):current]) # obtain the average of previous frames
previous_y = mean(y_coordinates[current-(self.threshold+1):current])
difference_x = sqrt((norm - previous_x)**2) # compare current frame to average of previous frames
difference_y = sqrt((norm - previous_y)**2)
# print(difference_x, difference_y)
if difference_x < 0.15 and difference_y < 0.15: # check whether difference between current and previous frames exceeds certain threshold (regulates tolerance/strictness)
if looking != True:
looking = True
self.yield_gaze_direction(engagement=EngagementType.High, gaze_direction=(gaze_x, gaze_y))
else:
if looking != False:
looking = False
self.yield_gaze_direction(engagement=EngagementType.Low, gaze_direction=(gaze_x, gaze_y))
except:
# import traceback
# traceback.print_exc()
pass
def dialog_end(self):
# Set openface to non-publishing mode and wait until it is ready
self.openface_endpoint.send(bytes(f"OPENFACE_END", encoding="ascii"))
if self.extractor_thread:
self.extractor_thread.join()
def dialog_exit(self):
# close openface process
self.p_openface.kill()
__init__(self, domain='', camera_id=0, openface_port=6004, delay=2, identifier=None)
special
¶
Parameters:
Name | Type | Description | Default |
---|---|---|---|
camera_id |
int |
index of the camera you want to use (if you only have one camera: 0) |
0 |
Source code in adviser/services/engagement/engagement_tracker.py
def __init__(self, domain="", camera_id: int = 0, openface_port: int = 6004, delay: int = 2, identifier=None):
    """
    Args:
        domain: accepted for interface compatibility; an empty domain is always passed on to `Service`
        camera_id: index of the camera you want to use (if you only have one camera: 0)
        openface_port: local TCP port used for the ZMQ connection to the OpenFace process
        delay: smoothing parameter, stored as `self.threshold`
        identifier: forwarded to `Service`
    """
    Service.__init__(self, domain="", identifier=identifier)
    self.camera_id = camera_id
    self.openface_port = openface_port
    self.openface_running = False
    # provide number of seconds as parameter, one second = 15 frames
    # NOTE(review): the value is stored unscaled (no multiplication by 15) - confirm which unit is intended.
    self.threshold = delay
    ctx = Context.instance()
    self.openface_endpoint = ctx.socket(zmq.PAIR)
    self.openface_endpoint.bind(f"tcp://127.0.0.1:{self.openface_port}")
    # Build the argument list directly (no string splitting, so paths with spaces survive)
    # and hand OpenFace the same port the socket above is bound to.
    executable = os.path.join(get_root_dir(), 'tools/OpenFace/build/bin/FaceLandmarkVidZMQ')
    command = [executable, "-device", str(self.camera_id), "-port", str(self.openface_port)]
    # The output of OpenFace is never read; discard it instead of piping it,
    # because a full, unread pipe would eventually block the child process.
    self.p_openface = subprocess.Popen(command, stdout=subprocess.DEVNULL)  # start OpenFace
    self.extracting = False
    self.extractor_thread = None
dialog_end(self)
¶
This function is called after a dialog has ended (a Topics.DIALOG_END message was received). You should override this function to record dialog-level information.
dialog_exit(self)
¶
dialog_start(self)
¶
This function is called before the first message to a new dialog is published. You should override this function to set/reset dialog-level variables.
Source code in adviser/services/engagement/engagement_tracker.py
def dialog_start(self):
    """Ask OpenFace to start publishing, block until it confirms, then launch the polling thread."""
    start_request = "OPENFACE_START".encode("ascii")
    self.openface_endpoint.send(start_request)
    self.extracting = False
    while not self.extracting:
        # wait for the started signal; any other message is ignored
        reply = self.openface_endpoint.recv().decode("utf-8")
        if reply != "OPENFACE_STARTED":
            continue
        print("START EXTRACTION")
        self.extracting = True
        self.extractor_thread = Thread(target=self.publish_gaze_directions)
        self.extractor_thread.start()
publish_gaze_directions(self)
¶
Meant to be used in a thread.
Runs an infinite loop polling features from the OpenFace library, parsing them and extracting engagement features.
Calls yield_gaze_direction
to publish the polled and processed engagement features.
Source code in adviser/services/engagement/engagement_tracker.py
def publish_gaze_directions(self):
    """
    Meant to be used in a thread.

    Polls OpenFace for features until `self.extracting` becomes False, parses the gaze angles and
    calls `yield_gaze_direction` whenever the engagement state (looking / not looking) changes.
    """
    x_coordinates = []
    y_coordinates = []
    norm = 0.0  # center point of screen; should be close(r) to 0
    tolerance = 0.15  # maximum deviation from `norm` that still counts as looking (regulates strictness)
    looking = True
    while self.extracting:
        self.openface_endpoint.send(b"OPENFACE_PULL")
        raw = self.openface_endpoint.recv()
        try:
            msg = raw.decode("utf-8")
            if msg == "OPENFACE_ENDED":
                self.extracting = False
                break
            msg_data = json.loads(msg)
            # only the magnitude of the angles matters
            gaze_x = abs(msg_data["gaze"]["angle"]["x"])  # gaze_angle_x (left-right movement)
            gaze_y = abs(msg_data["gaze"]["angle"]["y"])  # gaze_angle_y (up-down movement)
        except (ValueError, KeyError, TypeError):
            # malformed or incomplete frame (bad encoding, invalid JSON, missing gaze data): skip it
            continue

        x_coordinates.append(gaze_x)
        y_coordinates.append(gaze_y)
        window = self.threshold + 1  # number of previous frames that are averaged
        if window <= 0:
            continue  # nothing to average over
        # keep only the current frame plus the `window` frames before it, so the history stays bounded
        del x_coordinates[:-(window + 1)]
        del y_coordinates[:-(window + 1)]
        if len(x_coordinates) <= window:
            continue  # not enough history yet

        previous_x = mean(x_coordinates[:-1])  # obtain the average of previous frames
        previous_y = mean(y_coordinates[:-1])
        engaged = abs(norm - previous_x) < tolerance and abs(norm - previous_y) < tolerance
        if engaged != looking:
            # publish only when the engagement state flips
            looking = engaged
            level = EngagementType.High if engaged else EngagementType.Low
            self.yield_gaze_direction(engagement=level, gaze_direction=(gaze_x, gaze_y))
yield_gaze_direction(self, *args, **kwargs)
¶
Source code in adviser/services/engagement/engagement_tracker.py
def delegate(self, *args, **kwargs):
# Wrapper that runs the wrapped method `func` and forwards its result dict to the publish socket registered for it.
# `func`, `pub_topics` and `_send_msg` are not defined in this listing - presumably they come from the enclosing decorator; verify there.
# NOTE(review): this listing has lost its indentation, so the comments describe individual statements, not the nesting.
func_inst = getattr(self, func.__name__)
callargs = list(args)
if self in callargs: # remove self when in *args, because already known to function
callargs.remove(self)
result = func(self, *callargs, **kwargs)
if result:
# fix! (user could have multiple "/" characters in topic - only use last one )
# keys may have the form "topic/domain": keep the second "/"-separated part per topic ("" if there is no "/") ...
domains = {res.split("/")[0]: res.split("/")[1] if "/" in res else "" for res in result}
# ... and re-key the result by the part before the first "/"
result = {key.split("/")[0]: result[key] for key in result}
if func_inst not in self._publish_sockets:
# not a publisher, just normal function
return result
socket = self._publish_sockets[func_inst]
domain = self._domain_name
if socket and result:
# publish messages
for topic in pub_topics:
# for topic in result: # NOTE publish any returned value in dict with it's key as topic
if topic in result:
# a non-empty `domain` wins; otherwise fall back to the domain parsed from the result key
domain = domain if domain else domains[topic]
topic_domain_str = f"{topic}/{domain}" if domain else topic
# an entry in self._pub_topic_domains replaces the domain for this topic (an empty entry means: no domain suffix)
if topic in self._pub_topic_domains:
topic_domain_str = f"{topic}/{self._pub_topic_domains[topic]}" if self._pub_topic_domains[topic] else topic
_send_msg(socket, topic_domain_str, result[topic])
if self.debug_logger:
self.debug_logger.info(
f"- (DS): sent message from {func} to topic {topic_domain_str}:\n {result[topic]}")
return result
get_root_dir()
¶
hci
special
¶
__all__
special
¶
console
¶
The console module provides ADVISER modules that access the console for input and output.
ConsoleInput (Service)
¶
Gets the user utterance from the console.
Waits for the built-in input function to return a non-empty text.
Source code in adviser/services/hci/console.py
class ConsoleInput(Service):
    """
    Gets the user utterance from the console.

    Waits until a non-empty line has been read from standard input.
    """

    def __init__(self, domain: Domain = None, conversation_log_dir: str = None, language: Language = None):
        """
        Args:
            domain: forwarded to `Service`
            conversation_log_dir: if given, every user utterance is also written to a file in this directory
            language: currently ignored - the language is fixed to English (interactive selection is disabled)
        """
        Service.__init__(self, domain=domain)
        self.language = Language.ENGLISH
        self.conversation_log_dir = conversation_log_dir
        self.interaction_count = 0

    def dialog_start(self):
        """Reset the per-dialog interaction counter."""
        self.interaction_count = 0

    @PublishSubscribe(sub_topics=[Topic.DIALOG_END], pub_topics=["gen_user_utterance"])
    def get_user_input(self, dialog_end: bool = True) -> dict(user_utterance=str):
        """
        If the dialog has ended (`dialog_end` is true), do not pass a message.
        Otherwise, it blocks the application until the user has entered a
        valid (i.e. non-empty) message in the console.

        Returns:
            dict: a dict containing the user utterance, or None if `dialog_end` is true
        """
        if dialog_end:
            return

        utterance = self._input()
        # write into logging directory
        if self.conversation_log_dir is not None:
            log_name = str(math.floor(time.time())) + "_user.txt"
            # explicit encoding: the default is locale dependent and may not be able to encode the input
            with open(os.path.join(self.conversation_log_dir, log_name), "w", encoding="utf-8") as convo_log:
                convo_log.write(utterance)
        return {'gen_user_utterance': utterance}

    def _input(self):
        """
        Helper function for reading text input from the console.

        Keeps reading lines until one is non-blank or the dialog system is terminating.
        On any error the text read so far (possibly empty) is returned.
        """
        utterance = ''
        try:
            sys.stdout.write('>>> ')
            sys.stdout.flush()
            line = sys.stdin.readline()
            # NOTE(review): at end-of-file readline() returns '' immediately, so this loop spins
            # until the dialog system terminates - confirm whether EOF should end the input instead.
            while line.strip() == '' and not self._dialog_system_parent.terminating():
                line = sys.stdin.readline()
            utterance = line
            if self._dialog_system_parent.terminating():
                sys.stdin.close()
            return utterance
        except Exception:
            # best effort: reading from a closed or broken stdin must not crash the service,
            # but KeyboardInterrupt / SystemExit are no longer swallowed here
            return utterance

    def _set_language(self) -> Language:
        """
        asks the user to select the language of the system, returning the enum
        representing their preference, or None if they don't give a recognized
        input
        """
        utterance = ""
        print("Please select your language: English or German")
        while utterance.strip() == "":
            utterance = input(">>> ")
        # ignore surrounding whitespace and case when matching the answer
        choice = utterance.strip().lower()
        if choice in ('e', 'english'):
            return Language.ENGLISH
        if choice in ('g', 'german', 'deutsch', 'd'):
            return Language.GERMAN
        return None
__init__(self, domain=None, conversation_log_dir=None, language=None)
special
¶
Source code in adviser/services/hci/console.py
def __init__(self, domain: Domain = None, conversation_log_dir: str = None, language: Language = None):
    """
    Set up the console input service.

    Args:
        domain: forwarded to `Service`
        conversation_log_dir: stored for later use as the directory for conversation logs
        language: not used - the language is always set to English
    """
    Service.__init__(self, domain=domain)
    self.interaction_count = 0
    self.conversation_log_dir = conversation_log_dir
    # The `language` argument and the interactive selection via `_set_language` are disabled;
    # English is hard-wired.
    self.language = Language.ENGLISH
dialog_start(self)
¶
get_user_input(self, *args, **kwargs)
¶
Source code in adviser/services/hci/console.py
def delegate(self, *args, **kwargs):
# Wrapper that runs the wrapped method `func` and forwards its result dict to the publish socket registered for it.
# `func`, `pub_topics` and `_send_msg` are not defined in this listing - presumably they come from the enclosing decorator; verify there.
# NOTE(review): this listing has lost its indentation, so the comments describe individual statements, not the nesting.
func_inst = getattr(self, func.__name__)
callargs = list(args)
if self in callargs: # remove self when in *args, because already known to function
callargs.remove(self)
result = func(self, *callargs, **kwargs)
if result:
# fix! (user could have multiple "/" characters in topic - only use last one )
# keys may have the form "topic/domain": keep the second "/"-separated part per topic ("" if there is no "/") ...
domains = {res.split("/")[0]: res.split("/")[1] if "/" in res else "" for res in result}
# ... and re-key the result by the part before the first "/"
result = {key.split("/")[0]: result[key] for key in result}
if func_inst not in self._publish_sockets:
# not a publisher, just normal function
return result
socket = self._publish_sockets[func_inst]
domain = self._domain_name
if socket and result:
# publish messages
for topic in pub_topics:
# for topic in result: # NOTE publish any returned value in dict with it's key as topic
if topic in result:
# a non-empty `domain` wins; otherwise fall back to the domain parsed from the result key
domain = domain if domain else domains[topic]
topic_domain_str = f"{topic}/{domain}" if domain else topic
# an entry in self._pub_topic_domains replaces the domain for this topic (an empty entry means: no domain suffix)
if topic in self._pub_topic_domains:
topic_domain_str = f"{topic}/{self._pub_topic_domains[topic]}" if self._pub_topic_domains[topic] else topic
_send_msg(socket, topic_domain_str, result[topic])
if self.debug_logger:
self.debug_logger.info(
f"- (DS): sent message from {func} to topic {topic_domain_str}:\n {result[topic]}")
return result
ConsoleOutput (Service)
¶
Writes the system utterance to the console.
Source code in adviser/services/hci/console.py
class ConsoleOutput(Service):
    """Service that prints the system utterance on the console."""

    def __init__(self, domain: Domain = None):
        Service.__init__(self, domain=domain)

    @PublishSubscribe(sub_topics=["sys_utterance"], pub_topics=[Topic.DIALOG_END])
    def print_sys_utterance(self, sys_utterance: str = None) -> dict():
        """
        Print the system utterance and report whether the dialog is over.

        Args:
            sys_utterance (str): The system utterance

        Returns:
            dict with entry dialog_end: True if the utterance contains 'bye', False otherwise

        Raises:
            ValueError: if there is no system utterance to print (None or empty string)
        """
        # guard clause: nothing to print
        if sys_utterance is None or sys_utterance == "":
            raise ValueError("There is no system utterance. Did you forget to call an NLG module before?")

        print("System: {}".format(sys_utterance))
        dialog_finished = 'bye' in sys_utterance
        return {Topic.DIALOG_END: dialog_finished}
__init__(self, domain=None)
special
¶
print_sys_utterance(self, *args, **kwargs)
¶
Source code in adviser/services/hci/console.py
def delegate(self, *args, **kwargs):
# Wrapper that runs the wrapped method `func` and forwards its result dict to the publish socket registered for it.
# `func`, `pub_topics` and `_send_msg` are not defined in this listing - presumably they come from the enclosing decorator; verify there.
# NOTE(review): this listing has lost its indentation, so the comments describe individual statements, not the nesting.
func_inst = getattr(self, func.__name__)
callargs = list(args)
if self in callargs: # remove self when in *args, because already known to function
callargs.remove(self)
result = func(self, *callargs, **kwargs)
if result:
# fix! (user could have multiple "/" characters in topic - only use last one )
# keys may have the form "topic/domain": keep the second "/"-separated part per topic ("" if there is no "/") ...
domains = {res.split("/")[0]: res.split("/")[1] if "/" in res else "" for res in result}
# ... and re-key the result by the part before the first "/"
result = {key.split("/")[0]: result[key] for key in result}
if func_inst not in self._publish_sockets:
# not a publisher, just normal function
return result
socket = self._publish_sockets[func_inst]
domain = self._domain_name
if socket and result:
# publish messages
for topic in pub_topics:
# for topic in result: # NOTE publish any returned value in dict with it's key as topic
if topic in result:
# a non-empty `domain` wins; otherwise fall back to the domain parsed from the result key
domain = domain if domain else domains[topic]
topic_domain_str = f"{topic}/{domain}" if domain else topic
# an entry in self._pub_topic_domains replaces the domain for this topic (an empty entry means: no domain suffix)
if topic in self._pub_topic_domains:
topic_domain_str = f"{topic}/{self._pub_topic_domains[topic]}" if self._pub_topic_domains[topic] else topic
_send_msg(socket, topic_domain_str, result[topic])
if self.debug_logger:
self.debug_logger.info(
f"- (DS): sent message from {func} to topic {topic_domain_str}:\n {result[topic]}")
return result
gui
¶
GUIServer (Service)
¶
Source code in adviser/services/hci/gui.py
class GUIServer(Service):
    """Service that connects the dialog system to the web UI through a websocket."""

    def __init__(self, logger=None):
        """
        Create the service, prepare its own asyncio event loop and open the web UI in a browser.

        Args:
            logger: accepted but not used in this constructor
        """
        super().__init__(domain="", identifier="GUIServer")
        self.websocket = None
        self.loopy_loop = asyncio.new_event_loop()

        # open UI in webbrowser automatically
        chat_page = os.path.join(os.path.realpath(''), 'tools', 'webui', 'chat.html')
        webui_path = f"file:///{chat_page}"
        print("WEBUI accessible at", webui_path)
        webbrowser.open(webui_path)

    @PublishSubscribe(pub_topics=['gen_user_utterance'])
    def user_utterance(self, message = ""):
        """Publish `message` as the user utterance."""
        return {'gen_user_utterance': message}

    @PublishSubscribe(sub_topics=['sys_utterance'])
    def forward_sys_utterance(self, sys_utterance: str):
        """Pass a received system utterance on to the web UI."""
        self.forward_message_to_react(message=sys_utterance, topic="sys_utterance")

    def forward_message_to_react(self, message, topic: str):
        """Write `message` together with its `topic` to the websocket, if one is set."""
        asyncio.set_event_loop(self.loopy_loop)
        if not self.websocket:
            return
        payload = {"topic": topic, "msg": message}
        self.websocket.write_message(payload)
__init__(self, logger=None)
special
¶
Source code in adviser/services/hci/gui.py
def __init__(self, logger=None):
    """
    Create the service, prepare its own asyncio event loop and open the web UI in a browser.

    Args:
        logger: accepted but not used in this constructor
    """
    super().__init__(domain="", identifier="GUIServer")
    self.websocket = None
    self.loopy_loop = asyncio.new_event_loop()

    # open UI in webbrowser automatically
    chat_page = os.path.join(os.path.realpath(''), 'tools', 'webui', 'chat.html')
    webui_path = f"file:///{chat_page}"
    print("WEBUI accessible at", webui_path)
    webbrowser.open(webui_path)
forward_message_to_react(self, message, topic)
¶
forward_sys_utterance(self, *args, **kwargs)
¶
Source code in adviser/services/hci/gui.py
def delegate(self, *args, **kwargs):
# Wrapper that runs the wrapped method `func` and forwards its result dict to the publish socket registered for it.
# `func`, `pub_topics` and `_send_msg` are not defined in this listing - presumably they come from the enclosing decorator; verify there.
# NOTE(review): this listing has lost its indentation, so the comments describe individual statements, not the nesting.
func_inst = getattr(self, func.__name__)
callargs = list(args)
if self in callargs: # remove self when in *args, because already known to function
callargs.remove(self)
result = func(self, *callargs, **kwargs)
if result:
# fix! (user could have multiple "/" characters in topic - only use last one )
# keys may have the form "topic/domain": keep the second "/"-separated part per topic ("" if there is no "/") ...
domains = {res.split("/")[0]: res.split("/")[1] if "/" in res else "" for res in result}
# ... and re-key the result by the part before the first "/"
result = {key.split("/")[0]: result[key] for key in result}
if func_inst not in self._publish_sockets:
# not a publisher, just normal function
return result
socket = self._publish_sockets[func_inst]
domain = self._domain_name
if socket and result:
# publish messages
for topic in pub_topics:
# for topic in result: # NOTE publish any returned value in dict with it's key as topic
if topic in result:
# a non-empty `domain` wins; otherwise fall back to the domain parsed from the result key
domain = domain if domain else domains[topic]
topic_domain_str = f"{topic}/{domain}" if domain else topic
# an entry in self._pub_topic_domains replaces the domain for this topic (an empty entry means: no domain suffix)
if topic in self._pub_topic_domains:
topic_domain_str = f"{topic}/{self._pub_topic_domains[topic]}" if self._pub_topic_domains[topic] else topic
_send_msg(socket, topic_domain_str, result[topic])
if self.debug_logger:
self.debug_logger.info(
f"- (DS): sent message from {func} to topic {topic_domain_str}:\n {result[topic]}")
return result
user_utterance(self, *args, **kwargs)
¶
Source code in adviser/services/hci/gui.py
def delegate(self, *args, **kwargs):
# Wrapper that runs the wrapped method `func` and forwards its result dict to the publish socket registered for it.
# `func`, `pub_topics` and `_send_msg` are not defined in this listing - presumably they come from the enclosing decorator; verify there.
# NOTE(review): this listing has lost its indentation, so the comments describe individual statements, not the nesting.
func_inst = getattr(self, func.__name__)
callargs = list(args)
if self in callargs: # remove self when in *args, because already known to function
callargs.remove(self)
result = func(self, *callargs, **kwargs)
if result:
# fix! (user could have multiple "/" characters in topic - only use last one )
# keys may have the form "topic/domain": keep the second "/"-separated part per topic ("" if there is no "/") ...
domains = {res.split("/")[0]: res.split("/")[1] if "/" in res else "" for res in result}
# ... and re-key the result by the part before the first "/"
result = {key.split("/")[0]: result[key] for key in result}
if func_inst not in self._publish_sockets:
# not a publisher, just normal function
return result
socket = self._publish_sockets[func_inst]
domain = self._domain_name
if socket and result:
# publish messages
for topic in pub_topics:
# for topic in result: # NOTE publish any returned value in dict with it's key as topic
if topic in result:
# a non-empty `domain` wins; otherwise fall back to the domain parsed from the result key
domain = domain if domain else domains[topic]
topic_domain_str = f"{topic}/{domain}" if domain else topic
# an entry in self._pub_topic_domains replaces the domain for this topic (an empty entry means: no domain suffix)
if topic in self._pub_topic_domains:
topic_domain_str = f"{topic}/{self._pub_topic_domains[topic]}" if self._pub_topic_domains[topic] else topic
_send_msg(socket, topic_domain_str, result[topic])
if self.debug_logger:
self.debug_logger.info(
f"- (DS): sent message from {func} to topic {topic_domain_str}:\n {result[topic]}")
return result
speech
special
¶
SpeechInputDecoder
¶
SpeechInputDecoder (Service)
¶
Source code in adviser/services/hci/speech/SpeechInputDecoder.py
class SpeechInputDecoder(Service):
    """Speech recognition service: turns speech features into the text of the user utterance."""

    def __init__(self, domain: Domain = "", identifier=None, conversation_log_dir: str = None, use_cuda=False):
        """
        Transforms spoken input from the user to text for further processing.

        Args:
            domain (Domain): Needed for Service, but has no meaning here
            identifier (string): Needed for Service
            conversation_log_dir (string): If this is provided, logfiles will be placed by this Service into the specified directory.
            use_cuda (boolean): Whether or not to run the computations on a GPU
        """
        Service.__init__(self, domain=domain, identifier=identifier)
        self.conversation_log_dir = conversation_log_dir

        # load model
        model_dir = os.path.join(get_root_dir(), "resources", "models", "speech", "multi_en_20190916")
        self.model, conf = load_trained_model(os.path.join(model_dir, "model.bin"))
        self.vocab = conf.char_list

        # setup beam search (only the decoder score is weighted, the CTC weight is 0)
        self.bs = BeamSearch(scorers=self.model.scorers(),
                             weights={"decoder": 1.0, "ctc": 0.0},
                             sos=self.model.sos,
                             eos=self.model.eos,
                             beam_size=4,
                             vocab_size=len(self.vocab),
                             pre_beam_score_key="decoder")
        # swap the class of the already constructed search object for the batched implementation
        self.bs.__class__ = BatchBeamSearch

        # choose hardware to run on
        self.device = "cuda" if use_cuda else "cpu"
        self.model.to(self.device)
        self.bs.to(self.device)

        # change from training mode to eval mode
        self.model.eval()
        self.bs.eval()

        # scale and offset for feature normalization
        # follows https://github.com/kaldi-asr/kaldi/blob/33255ed224500f55c8387f1e4fa40e08b73ff48a/src/transform/cmvn.cc#L92-L111
        # NOTE(review): torch.load unpickles the file - only load statistics from a trusted model directory.
        norm = torch.load(os.path.join(model_dir, "cmvn.bin"))
        count = norm[0][-1]  # last entry of the first row is the count
        mean = norm[0][:-1] / count
        var = (norm[1][:-1] / count) - mean * mean
        self.scale = 1.0 / torch.sqrt(var)
        self.offset = - (mean * self.scale)

    @PublishSubscribe(sub_topics=["speech_features"], pub_topics=["gen_user_utterance"])
    def features_to_text(self, speech_features):
        """
        Turns features of the utterance into a string and returns the user utterance in form of text

        Args:
            speech_features (np.array): The features that the speech feature extraction module produces

        Returns:
            dict(string, string): The user utterance as text
        """
        speech_in_features_normalized = torch.from_numpy(speech_features) * self.scale + self.offset
        with torch.no_grad():
            encoded = self.model.encode(speech_in_features_normalized.to(self.device))
            result = self.bs.forward(encoded)

        # We only consider the most probable hypothesis.
        # Language Model could improve this, right now we don't use one.
        # This might need some post-processing...
        user_utterance = "".join(self.vocab[y] for y in result[0].yseq) \
            .replace("▁", " ") \
            .replace("<space>", " ") \
            .replace("<eos>", "") \
            .strip()

        # write decoded text into logging directory
        if self.conversation_log_dir is not None:
            # int() replaces np.math.floor(): the np.math alias no longer exists in current NumPy releases
            log_name = str(int(time.time())) + "_user.txt"
            # explicit encoding: the default is locale dependent and may not be able to encode the text
            with open(os.path.join(self.conversation_log_dir, log_name), "w", encoding="utf-8") as convo_log:
                convo_log.write(user_utterance)

        print("User: {}\n".format(user_utterance))
        return {'gen_user_utterance': user_utterance}
__init__(self, domain='', identifier=None, conversation_log_dir=None, use_cuda=False)
special
¶
Transforms spoken input from the user to text for further processing.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
domain |
Domain |
Needed for Service, but has no meaning here |
'' |
identifier |
string |
Needed for Service |
None |
conversation_log_dir |
string |
If this is provided, logfiles will be placed by this Service into the specified directory. |
None |
use_cuda |
boolean |
Whether or not to run the computations on a GPU |
False |
Source code in adviser/services/hci/speech/SpeechInputDecoder.py
def __init__(self, domain: Domain = "", identifier=None, conversation_log_dir: str = None, use_cuda=False):
    """
    Transforms spoken input from the user to text for further processing.

    Args:
        domain (Domain): Needed for Service, but has no meaning here
        identifier (string): Needed for Service
        conversation_log_dir (string): If this is provided, logfiles will be placed by this Service into the specified directory.
        use_cuda (boolean): Whether or not to run the computations on a GPU
    """
    Service.__init__(self, domain=domain, identifier=identifier)
    self.conversation_log_dir = conversation_log_dir

    # load the trained model and its vocabulary
    model_dir = os.path.join(get_root_dir(), "resources", "models", "speech", "multi_en_20190916")
    model_file = os.path.join(model_dir, "model.bin")
    self.model, conf = load_trained_model(model_file)
    self.vocab = conf.char_list

    # beam search setup: only the decoder score is weighted, the CTC weight is 0
    search_options = dict(scorers=self.model.scorers(),
                          weights={"decoder": 1.0, "ctc": 0.0},
                          sos=self.model.sos,
                          eos=self.model.eos,
                          beam_size=4,
                          vocab_size=len(self.vocab),
                          pre_beam_score_key="decoder")
    self.bs = BeamSearch(**search_options)
    # swap the class of the already constructed search object for the batched implementation
    self.bs.__class__ = BatchBeamSearch

    # choose hardware to run on
    self.device = "cuda" if use_cuda else "cpu"
    self.model.to(self.device)
    self.bs.to(self.device)

    # change from training mode to eval mode
    self.model.eval()
    self.bs.eval()

    # scale and offset for feature normalization
    # follows https://github.com/kaldi-asr/kaldi/blob/33255ed224500f55c8387f1e4fa40e08b73ff48a/src/transform/cmvn.cc#L92-L111
    stats = torch.load(os.path.join(model_dir, "cmvn.bin"))
    frame_count = stats[0][-1]  # last entry of the first row is the count
    feature_mean = stats[0][:-1] / frame_count
    feature_var = (stats[1][:-1] / frame_count) - feature_mean * feature_mean
    self.scale = 1.0 / torch.sqrt(feature_var)
    self.offset = - (feature_mean * self.scale)
features_to_text(self, *args, **kwargs)
¶
Source code in adviser/services/hci/speech/SpeechInputDecoder.py
def delegate(self, *args, **kwargs):
# Wrapper that runs the wrapped method `func` and forwards its result dict to the publish socket registered for it.
# `func`, `pub_topics` and `_send_msg` are not defined in this listing - presumably they come from the enclosing decorator; verify there.
# NOTE(review): this listing has lost its indentation, so the comments describe individual statements, not the nesting.
func_inst = getattr(self, func.__name__)
callargs = list(args)
if self in callargs: # remove self when in *args, because already known to function
callargs.remove(self)
result = func(self, *callargs, **kwargs)
if result:
# fix! (user could have multiple "/" characters in topic - only use last one )
# keys may have the form "topic/domain": keep the second "/"-separated part per topic ("" if there is no "/") ...
domains = {res.split("/")[0]: res.split("/")[1] if "/" in res else "" for res in result}
# ... and re-key the result by the part before the first "/"
result = {key.split("/")[0]: result[key] for key in result}
if func_inst not in self._publish_sockets:
# not a publisher, just normal function
return result
socket = self._publish_sockets[func_inst]
domain = self._domain_name
if socket and result:
# publish messages
for topic in pub_topics:
# for topic in result: # NOTE publish any returned value in dict with it's key as topic
if topic in result:
# a non-empty `domain` wins; otherwise fall back to the domain parsed from the result key
domain = domain if domain else domains[topic]
topic_domain_str = f"{topic}/{domain}" if domain else topic
# an entry in self._pub_topic_domains replaces the domain for this topic (an empty entry means: no domain suffix)
if topic in self._pub_topic_domains:
topic_domain_str = f"{topic}/{self._pub_topic_domains[topic]}" if self._pub_topic_domains[topic] else topic
_send_msg(socket, topic_domain_str, result[topic])
if self.debug_logger:
self.debug_logger.info(
f"- (DS): sent message from {func} to topic {topic_domain_str}:\n {result[topic]}")
return result
get_root_dir()
¶
SpeechInputFeatureExtractor
¶
SpeechInputFeatureExtractor (Service)
¶
Source code in adviser/services/hci/speech/SpeechInputFeatureExtractor.py
class SpeechInputFeatureExtractor(Service):
    """Service that converts recorded audio into acoustic features."""

    def __init__(self, domain: Domain = ""):
        """
        Given a sound, this service extracts features and passes them on to the decoder for ASR

        Args:
            domain (Domain): Needed for Service, no meaning here
        """
        Service.__init__(self, domain=domain)

    @PublishSubscribe(sub_topics=["speech_in"], pub_topics=["speech_features"])
    def speech_to_features(self, speech_in: Tuple[numpy.array, int]):
        """
        Turns numpy array with utterance into features

        Args:
            speech_in (tuple(np.array), int): The utterance, represented as array and the sampling rate

        Returns:
            dict with entry 'speech_features': np.array holding 80 filterbank features
            followed by 3 zero-valued pitch columns per frame
        """
        waveform, sampling_rate = speech_in[0], speech_in[1]
        signal = torch.from_numpy(waveform).unsqueeze(0)
        # Default ASR model uses 16kHz, but different models are possible, then the sampling rate only needs to be changed in the recorder
        mel_features = torchaudio.compliance.kaldi.fbank(signal, num_mel_bins=80, sample_frequency=sampling_rate)
        # placeholder pitch features: three zero columns, one row per frame
        zero_pitch = torch.zeros(mel_features.shape[0], 3)  # TODO: check if torchaudio pitch function is better
        combined = torch.cat([mel_features, zero_pitch], 1)
        return {'speech_features': combined.numpy()}

    @PublishSubscribe(sub_topics=["speech_in"], pub_topics=["mfcc"])
    def speech_to_mfcc(self, speech_in):
        """
        Extracts Mel Frequency Cepstral Coefficients (MFCC) from input utterance
        (13 coefficients according to the original documentation; no count is passed explicitly).

        Args:
            speech_in (tuple(np.array), int): The utterance, represented as array and the sampling rate

        Returns:
            dict with entry 'mfcc': the features exactly as returned by torchaudio
            (no conversion to a numpy array is applied here)
        """
        waveform, sampling_rate = speech_in[0], speech_in[1]
        signal = torch.from_numpy(waveform).unsqueeze(0)
        coefficients = torchaudio.compliance.kaldi.mfcc(signal, sample_frequency=sampling_rate)
        return {'mfcc': coefficients}

    @PublishSubscribe(sub_topics=["speech_in"], pub_topics=["fbank"])
    def speech_to_fbank(self, speech_in):
        """
        Extracts filterbank features from input utterance
        (23 filterbanks according to the original documentation; no count is passed explicitly).

        Args:
            speech_in (tuple(np.array), int): The utterance, represented as array and the sampling rate

        Returns:
            dict with entry 'fbank': the features exactly as returned by torchaudio
            (no conversion to a numpy array is applied here)
        """
        waveform, sampling_rate = speech_in[0], speech_in[1]
        signal = torch.from_numpy(waveform).unsqueeze(0)
        filter_banks = torchaudio.compliance.kaldi.fbank(signal, sample_frequency=sampling_rate)
        return {'fbank': filter_banks}
__init__(self, domain='')
special
¶
Given a sound, this service extracts features and passes them on to the decoder for ASR
Parameters:
Name | Type | Description | Default |
---|---|---|---|
domain |
Domain |
Needed for Service, no meaning here |
'' |
Source code in adviser/services/hci/speech/SpeechInputFeatureExtractor.py
speech_to_fbank(self, *args, **kwargs)
¶
Source code in adviser/services/hci/speech/SpeechInputFeatureExtractor.py
def delegate(self, *args, **kwargs):
# Wrapper that runs the wrapped method `func` and forwards its result dict to the publish socket registered for it.
# `func`, `pub_topics` and `_send_msg` are not defined in this listing - presumably they come from the enclosing decorator; verify there.
# NOTE(review): this listing has lost its indentation, so the comments describe individual statements, not the nesting.
func_inst = getattr(self, func.__name__)
callargs = list(args)
if self in callargs: # remove self when in *args, because already known to function
callargs.remove(self)
result = func(self, *callargs, **kwargs)
if result:
# fix! (user could have multiple "/" characters in topic - only use last one )
# keys may have the form "topic/domain": keep the second "/"-separated part per topic ("" if there is no "/") ...
domains = {res.split("/")[0]: res.split("/")[1] if "/" in res else "" for res in result}
# ... and re-key the result by the part before the first "/"
result = {key.split("/")[0]: result[key] for key in result}
if func_inst not in self._publish_sockets:
# not a publisher, just normal function
return result
socket = self._publish_sockets[func_inst]
domain = self._domain_name
if socket and result:
# publish messages
for topic in pub_topics:
# for topic in result: # NOTE publish any returned value in dict with it's key as topic
if topic in result:
# a non-empty `domain` wins; otherwise fall back to the domain parsed from the result key
domain = domain if domain else domains[topic]
topic_domain_str = f"{topic}/{domain}" if domain else topic
# an entry in self._pub_topic_domains replaces the domain for this topic (an empty entry means: no domain suffix)
if topic in self._pub_topic_domains:
topic_domain_str = f"{topic}/{self._pub_topic_domains[topic]}" if self._pub_topic_domains[topic] else topic
_send_msg(socket, topic_domain_str, result[topic])
if self.debug_logger:
self.debug_logger.info(
f"- (DS): sent message from {func} to topic {topic_domain_str}:\n {result[topic]}")
return result
speech_to_features(self, *args, **kwargs)
¶
Source code in adviser/services/hci/speech/SpeechInputFeatureExtractor.py
def delegate(self, *args, **kwargs):
# Wrapper that runs the wrapped method `func` and forwards its result dict to the publish socket registered for it.
# `func`, `pub_topics` and `_send_msg` are not defined in this listing - presumably they come from the enclosing decorator; verify there.
# NOTE(review): this listing has lost its indentation, so the comments describe individual statements, not the nesting.
func_inst = getattr(self, func.__name__)
callargs = list(args)
if self in callargs: # remove self when in *args, because already known to function
callargs.remove(self)
result = func(self, *callargs, **kwargs)
if result:
# fix! (user could have multiple "/" characters in topic - only use last one )
# keys may have the form "topic/domain": keep the second "/"-separated part per topic ("" if there is no "/") ...
domains = {res.split("/")[0]: res.split("/")[1] if "/" in res else "" for res in result}
# ... and re-key the result by the part before the first "/"
result = {key.split("/")[0]: result[key] for key in result}
if func_inst not in self._publish_sockets:
# not a publisher, just normal function
return result
socket = self._publish_sockets[func_inst]
domain = self._domain_name
if socket and result:
# publish messages
for topic in pub_topics:
# for topic in result: # NOTE publish any returned value in dict with it's key as topic
if topic in result:
# a non-empty `domain` wins; otherwise fall back to the domain parsed from the result key
domain = domain if domain else domains[topic]
topic_domain_str = f"{topic}/{domain}" if domain else topic
# an entry in self._pub_topic_domains replaces the domain for this topic (an empty entry means: no domain suffix)
if topic in self._pub_topic_domains:
topic_domain_str = f"{topic}/{self._pub_topic_domains[topic]}" if self._pub_topic_domains[topic] else topic
_send_msg(socket, topic_domain_str, result[topic])
if self.debug_logger:
self.debug_logger.info(
f"- (DS): sent message from {func} to topic {topic_domain_str}:\n {result[topic]}")
return result
speech_to_mfcc(self, *args, **kwargs)
¶
Source code in adviser/services/hci/speech/SpeechInputFeatureExtractor.py
def delegate(self, *args, **kwargs):
# Wrapper that runs the wrapped method `func` and forwards its result dict to the publish socket registered for it.
# `func`, `pub_topics` and `_send_msg` are not defined in this listing - presumably they come from the enclosing decorator; verify there.
# NOTE(review): this listing has lost its indentation, so the comments describe individual statements, not the nesting.
func_inst = getattr(self, func.__name__)
callargs = list(args)
if self in callargs: # remove self when in *args, because already known to function
callargs.remove(self)
result = func(self, *callargs, **kwargs)
if result:
# fix! (user could have multiple "/" characters in topic - only use last one )
# keys may have the form "topic/domain": keep the second "/"-separated part per topic ("" if there is no "/") ...
domains = {res.split("/")[0]: res.split("/")[1] if "/" in res else "" for res in result}
# ... and re-key the result by the part before the first "/"
result = {key.split("/")[0]: result[key] for key in result}
if func_inst not in self._publish_sockets:
# not a publisher, just normal function
return result
socket = self._publish_sockets[func_inst]
domain = self._domain_name
if socket and result:
# publish messages
for topic in pub_topics:
# for topic in result: # NOTE publish any returned value in dict with it's key as topic
if topic in result:
# a non-empty `domain` wins; otherwise fall back to the domain parsed from the result key
domain = domain if domain else domains[topic]
topic_domain_str = f"{topic}/{domain}" if domain else topic
# an entry in self._pub_topic_domains replaces the domain for this topic (an empty entry means: no domain suffix)
if topic in self._pub_topic_domains:
topic_domain_str = f"{topic}/{self._pub_topic_domains[topic]}" if self._pub_topic_domains[topic] else topic
_send_msg(socket, topic_domain_str, result[topic])
if self.debug_logger:
self.debug_logger.info(
f"- (DS): sent message from {func} to topic {topic_domain_str}:\n {result[topic]}")
return result
SpeechOutputGenerator
¶
SpeechOutputGenerator (Service)
¶
Source code in adviser/services/hci/speech/SpeechOutputGenerator.py
class SpeechOutputGenerator(Service):
    """Service that turns the textual system utterance into a waveform."""

    def __init__(self, domain: Domain = "", identifier: str = None, use_cuda=False, sub_topic_domains: Dict[str, str] = {}):
        """
        Text To Speech Module that reads out the system utterance.

        Args:
            domain (Domain): Needed for Service, no meaning here
            identifier (string): Needed for Service
            use_cuda (boolean): Whether or not to perform computations on GPU. Highly recommended if available
            sub_topic_domains: see `services.service.Service` constructor for more details
        """
        Service.__init__(self, domain=domain, identifier=identifier, sub_topic_domains=sub_topic_domains)
        self.models_directory = os.path.join(get_root_dir(), "resources", "models", "speech")

        # To plug in different models only the names in this section have to
        # change; everything below is derived from them.
        self.transcription_type = "phn"
        tts_name = "phn_train_no_dev_pytorch_train_fastspeech.v4"
        vocoder_name = "ljspeech.parallel_wavegan.v1"
        self.dict_path = os.path.join(self.models_directory, tts_name, "data", "lang_1phn",
                                      "train_no_dev_units.txt")
        self.model_path = os.path.join(self.models_directory, tts_name, "exp", tts_name, "results",
                                       "model.last1.avg.best")
        self.vocoder_path = os.path.join(self.models_directory, vocoder_name, "checkpoint-400000steps.pkl")
        self.vocoder_conf = os.path.join(self.models_directory, vocoder_name, "config.yml")

        # device the synthesis runs on
        self.device = torch.device("cuda" if use_cuda else "cpu")

        # end to end TTS model
        self.input_dimensions, self.output_dimensions, self.train_args = get_model_conf(self.model_path)
        model_class = dynamic_import.dynamic_import(self.train_args.model_module)
        tts_model = model_class(self.input_dimensions, self.output_dimensions, self.train_args)
        torch_load(self.model_path, tts_model)
        self.model = tts_model.eval().to(self.device)
        self.inference_args = Namespace(threshold=0.5, minlenratio=0.0, maxlenratio=10.0)

        # neural vocoder
        # NOTE(review): yaml.Loader is the full (unsafe) loader; the config file must be trusted.
        with open(self.vocoder_conf) as vocoder_config_file:
            self.config = yaml.load(vocoder_config_file, Loader=yaml.Loader)
        wavegan = ParallelWaveGANGenerator(**self.config["generator_params"])
        checkpoint = torch.load(self.vocoder_path, map_location="cpu")
        wavegan.load_state_dict(checkpoint["model"]["generator"])
        wavegan.remove_weight_norm()
        self.vocoder = wavegan.eval().to(self.device)

        # symbol -> id table, one "<symbol> <id>" pair per line
        with open(self.dict_path) as dictionary_file:
            entries = [line.replace("\n", "").split(" ") for line in dictionary_file]
        self.char_to_id = {symbol: int(index) for symbol, index in entries}
        self.g2p = G2p()

        # Fetch the pretrained Punkt tokenizer from NLTK; quiet=True keeps the
        # call silent when the resource is already present.
        nltk.download('punkt', quiet=True)

    def preprocess_text_input(self, text):
        """
        Clean the text and then convert it to id sequence.

        Args:
            text (string): The text to preprocess

        Returns:
            1-D LongTensor of symbol ids (terminated by the <eos> id) on self.device
        """
        text = custom_english_cleaners(text)  # cleans the text
        if self.transcription_type == "phn":
            # phoneme models: run grapheme-to-phoneme conversion and drop the word separators
            phonemes = [token for token in self.g2p(text) if token != " "]
            char_sequence = " ".join(phonemes).split(" ")
        else:
            char_sequence = list(text)

        id_sequence = []
        for symbol in char_sequence:
            if symbol.isspace():
                key = "<space>"
            elif symbol in self.char_to_id:
                key = symbol
            else:
                key = "<unk>"
            id_sequence.append(self.char_to_id[key])
        id_sequence.append(self.input_dimensions - 1)  # <eos>
        return torch.LongTensor(id_sequence).view(-1).to(self.device)

    @PublishSubscribe(sub_topics=["sys_utterance"], pub_topics=["system_speech"])
    def generate_speech(self, sys_utterance):
        """
        Takes the system utterance and turns it into a sound

        Args:
            sys_utterance (string): The new system utterance

        Returns:
            dict(string, tuple(np.array, int, string)): Everything needed to play the system utterance as an audio and the utterance in text for logging
        """
        with torch.no_grad():
            token_ids = self.preprocess_text_input(sys_utterance)
            acoustic_features, _, _ = self.model.inference(token_ids, self.inference_args)

            # the vocoder is driven by noise with one sample per output sample
            sample_count = acoustic_features.size(0) * self.config["hop_size"]
            noise = torch.randn(1, 1, sample_count).to(self.device)

            # pad the conditioning features by the vocoder's context window on both sides
            context_window = self.config["generator_params"]["aux_context_window"]
            conditioning = acoustic_features.unsqueeze(0).transpose(2, 1)
            conditioning = torch.nn.ReplicationPad1d(context_window)(conditioning)

            generated_speech = self.vocoder(noise, conditioning).view(-1)
            sound_as_array = generated_speech.cpu().numpy()
        return {"system_speech": (sound_as_array, self.config["sampling_rate"], sys_utterance)}
__init__(self, domain='', identifier=None, use_cuda=False, sub_topic_domains={})
special
¶
Text To Speech Module that reads out the system utterance.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
domain |
Domain |
Needed for Service, no meaning here |
'' |
identifier |
string |
Needed for Service |
None |
use_cuda |
boolean |
Whether or not to perform computations on GPU. Highly recommended if available |
False |
sub_topic_domains |
Dict[str, str] |
see |
{} |
Source code in adviser/services/hci/speech/SpeechOutputGenerator.py
def __init__(self, domain: Domain = "", identifier: str = None, use_cuda=False, sub_topic_domains: Dict[str, str] = {}):
    """
    Text To Speech Module that reads out the system utterance.

    Args:
        domain (Domain): Needed for Service, no meaning here
        identifier (string): Needed for Service
        use_cuda (boolean): Whether or not to perform computations on GPU. Highly recommended if available
        sub_topic_domains: see `services.service.Service` constructor for more details
    """
    Service.__init__(self, domain=domain, identifier=identifier, sub_topic_domains=sub_topic_domains)
    self.models_directory = os.path.join(get_root_dir(), "resources", "models", "speech")

    # To plug in different models only the names in this section have to
    # change; everything below is derived from them.
    self.transcription_type = "phn"
    tts_name = "phn_train_no_dev_pytorch_train_fastspeech.v4"
    vocoder_name = "ljspeech.parallel_wavegan.v1"
    self.dict_path = os.path.join(self.models_directory, tts_name, "data", "lang_1phn",
                                  "train_no_dev_units.txt")
    self.model_path = os.path.join(self.models_directory, tts_name, "exp", tts_name, "results",
                                   "model.last1.avg.best")
    self.vocoder_path = os.path.join(self.models_directory, vocoder_name, "checkpoint-400000steps.pkl")
    self.vocoder_conf = os.path.join(self.models_directory, vocoder_name, "config.yml")

    # device the synthesis runs on
    self.device = torch.device("cuda" if use_cuda else "cpu")

    # end to end TTS model
    self.input_dimensions, self.output_dimensions, self.train_args = get_model_conf(self.model_path)
    model_class = dynamic_import.dynamic_import(self.train_args.model_module)
    tts_model = model_class(self.input_dimensions, self.output_dimensions, self.train_args)
    torch_load(self.model_path, tts_model)
    self.model = tts_model.eval().to(self.device)
    self.inference_args = Namespace(threshold=0.5, minlenratio=0.0, maxlenratio=10.0)

    # neural vocoder
    # NOTE(review): yaml.Loader is the full (unsafe) loader; the config file must be trusted.
    with open(self.vocoder_conf) as vocoder_config_file:
        self.config = yaml.load(vocoder_config_file, Loader=yaml.Loader)
    wavegan = ParallelWaveGANGenerator(**self.config["generator_params"])
    checkpoint = torch.load(self.vocoder_path, map_location="cpu")
    wavegan.load_state_dict(checkpoint["model"]["generator"])
    wavegan.remove_weight_norm()
    self.vocoder = wavegan.eval().to(self.device)

    # symbol -> id table, one "<symbol> <id>" pair per line
    with open(self.dict_path) as dictionary_file:
        entries = [line.replace("\n", "").split(" ") for line in dictionary_file]
    self.char_to_id = {symbol: int(index) for symbol, index in entries}
    self.g2p = G2p()

    # Fetch the pretrained Punkt tokenizer from NLTK; quiet=True keeps the
    # call silent when the resource is already present.
    nltk.download('punkt', quiet=True)
generate_speech(self, *args, **kwargs)
¶
Source code in adviser/services/hci/speech/SpeechOutputGenerator.py
def delegate(self, *args, **kwargs):
    """
    Call the wrapped function and publish the topics it returned.

    ``func`` and ``pub_topics`` are free variables taken from the enclosing
    scope (not visible in this listing).

    The wrapped function may return a dict whose keys are either ``"topic"``
    or ``"topic/domain"``. Keys are reduced to the bare topic name before the
    dict is handed back to the caller.

    Domain used for publishing a topic, in order of precedence:
        1. the entry in ``self._pub_topic_domains`` for that topic, if present
        2. ``self._domain_name``, if non-empty
        3. the domain embedded in the key returned by the wrapped function

    Returns:
        the (key-normalised) result of the wrapped function
    """
    func_inst = getattr(self, func.__name__)
    callargs = list(args)
    if self in callargs:  # remove self when in *args, because already known to function
        callargs.remove(self)

    result = func(self, *callargs, **kwargs)
    if result:
        # NOTE: only the first two "/"-separated segments of a key are used;
        # keys with more than one "/" lose everything after the second segment.
        domains = {}
        normalised = {}
        for key in result:
            segments = key.split("/")
            domains[segments[0]] = segments[1] if "/" in key else ""
            normalised[segments[0]] = result[key]
        result = normalised

    if func_inst not in self._publish_sockets:
        # not a publisher, just normal function
        return result

    socket = self._publish_sockets[func_inst]
    if socket and result:
        # publish messages
        for topic in pub_topics:
            if topic not in result:
                continue
            if topic in self._pub_topic_domains:
                override = self._pub_topic_domains[topic]
                topic_domain_str = f"{topic}/{override}" if override else topic
            else:
                # Resolve the domain per topic: previously the domain found for
                # the first topic was carried over to all following topics.
                topic_domain = self._domain_name if self._domain_name else domains[topic]
                topic_domain_str = f"{topic}/{topic_domain}" if topic_domain else topic
            _send_msg(socket, topic_domain_str, result[topic])
            if self.debug_logger:
                self.debug_logger.info(
                    f"- (DS): sent message from {func} to topic {topic_domain_str}:\n {result[topic]}")
    return result
preprocess_text_input(self, text)
¶
Clean the text and then convert it to id sequence.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
text |
string |
The text to preprocess |
required |
Source code in adviser/services/hci/speech/SpeechOutputGenerator.py
def preprocess_text_input(self, text):
    """
    Clean the text and then convert it to id sequence.

    Args:
        text (string): The text to preprocess

    Returns:
        1-D LongTensor of symbol ids (terminated by the <eos> id) on self.device
    """
    text = custom_english_cleaners(text)  # cleans the text
    if self.transcription_type == "phn":
        # phoneme models: run grapheme-to-phoneme conversion and drop the word separators
        phonemes = [token for token in self.g2p(text) if token != " "]
        char_sequence = " ".join(phonemes).split(" ")
    else:
        char_sequence = list(text)

    id_sequence = []
    for symbol in char_sequence:
        if symbol.isspace():
            key = "<space>"
        elif symbol in self.char_to_id:
            key = symbol
        else:
            key = "<unk>"
        id_sequence.append(self.char_to_id[key])
    id_sequence.append(self.input_dimensions - 1)  # <eos>
    return torch.LongTensor(id_sequence).view(-1).to(self.device)
get_root_dir()
¶
SpeechOutputPlayer
¶
SpeechOutputPlayer (Service)
¶
Source code in adviser/services/hci/speech/SpeechOutputPlayer.py
class SpeechOutputPlayer(Service):
    """Service that plays synthesised system speech and optionally logs it."""

    def __init__(self, domain: Domain = "", conversation_log_dir: str = None, identifier: str = None):
        """
        Service that plays the system utterance as sound

        Args:
            domain (Domain): Needed for Service, but has no meaning here
            conversation_log_dir (string): If this is provided it will create log files in the specified directory.
            identifier (string): Needed for Service.
        """
        Service.__init__(self, domain=domain, identifier=identifier)
        self.interaction_count = 0
        self.conversation_log_dir = conversation_log_dir

    @PublishSubscribe(sub_topics=["system_speech"], pub_topics=[])
    def speak(self, system_speech):
        """
        Takes the system utterance and reads it out. Also can log the audio and text.

        Args:
            system_speech: indexable triple; element 0 is the audio array,
                element 1 the sampling rate and element 2 the utterance text
                (the text is only read when logging is enabled).
        """
        waveform, sampling_rate = system_speech[0], system_speech[1]
        sounddevice.play(waveform, sampling_rate)

        if self.conversation_log_dir is None:
            return

        # log the utterance: audio and text share a timestamp-based file stem
        timestamp = str(math.floor(time.time()))
        file_path = os.path.join(self.conversation_log_dir, timestamp)
        sf.write(file_path + "_system.wav", waveform, sampling_rate, 'PCM_24')
        with open(file_path + "_system.txt", "w") as convo_log:
            convo_log.write(system_speech[2])
__init__(self, domain='', conversation_log_dir=None, identifier=None)
special
¶
Service that plays the system utterance as sound
Parameters:
Name | Type | Description | Default |
---|---|---|---|
domain |
Domain |
Needed for Service, but has no meaning here |
'' |
conversation_log_dir |
string |
If this is provided it will create log files in the specified directory. |
None |
identifier |
string |
Needed for Service. |
None |
Source code in adviser/services/hci/speech/SpeechOutputPlayer.py
def __init__(self, domain: Domain = "", conversation_log_dir: str = None, identifier: str = None):
    """
    Service that plays the system utterance as sound

    Args:
        domain (Domain): Needed for Service, but has no meaning here
        conversation_log_dir (string): If this is provided it will create log files in the specified directory.
        identifier (string): Needed for Service.
    """
    Service.__init__(self, domain=domain, identifier=identifier)
    self.interaction_count = 0
    self.conversation_log_dir = conversation_log_dir
speak(self, *args, **kwargs)
¶
Source code in adviser/services/hci/speech/SpeechOutputPlayer.py
def delegate(self, *args, **kwargs):
    """
    Call the wrapped function and publish the topics it returned.

    ``func`` and ``pub_topics`` are free variables taken from the enclosing
    scope (not visible in this listing).

    The wrapped function may return a dict whose keys are either ``"topic"``
    or ``"topic/domain"``. Keys are reduced to the bare topic name before the
    dict is handed back to the caller.

    Domain used for publishing a topic, in order of precedence:
        1. the entry in ``self._pub_topic_domains`` for that topic, if present
        2. ``self._domain_name``, if non-empty
        3. the domain embedded in the key returned by the wrapped function

    Returns:
        the (key-normalised) result of the wrapped function
    """
    func_inst = getattr(self, func.__name__)
    callargs = list(args)
    if self in callargs:  # remove self when in *args, because already known to function
        callargs.remove(self)

    result = func(self, *callargs, **kwargs)
    if result:
        # NOTE: only the first two "/"-separated segments of a key are used;
        # keys with more than one "/" lose everything after the second segment.
        domains = {}
        normalised = {}
        for key in result:
            segments = key.split("/")
            domains[segments[0]] = segments[1] if "/" in key else ""
            normalised[segments[0]] = result[key]
        result = normalised

    if func_inst not in self._publish_sockets:
        # not a publisher, just normal function
        return result

    socket = self._publish_sockets[func_inst]
    if socket and result:
        # publish messages
        for topic in pub_topics:
            if topic not in result:
                continue
            if topic in self._pub_topic_domains:
                override = self._pub_topic_domains[topic]
                topic_domain_str = f"{topic}/{override}" if override else topic
            else:
                # Resolve the domain per topic: previously the domain found for
                # the first topic was carried over to all following topics.
                topic_domain = self._domain_name if self._domain_name else domains[topic]
                topic_domain_str = f"{topic}/{topic_domain}" if topic_domain else topic
            _send_msg(socket, topic_domain_str, result[topic])
            if self.debug_logger:
                self.debug_logger.info(
                    f"- (DS): sent message from {func} to topic {topic_domain_str}:\n {result[topic]}")
    return result
SpeechRecorder
¶
SpeechRecorder (Service)
¶
Source code in adviser/services/hci/speech/SpeechRecorder.py
class SpeechRecorder(Service):
    """Push-to-talk microphone recorder that publishes the recorded utterance."""

    def __init__(self, domain: Union[str, Domain] = "", conversation_log_dir: str = None, enable_plotting: bool = False, threshold: int = 8000,
                 voice_privacy: bool = False, identifier: str = None) -> None:
        """
        A service that can record a microphone upon a key pressing event
        and publish the result as an array. The end of the utterance is
        detected automatically, also the voice can be masked to alleviate
        privacy issues.

        Args:
            domain (Domain): Passed on to Service; not used by the recorder itself.
            conversation_log_dir (string): If this parameter is given, log files of the conversation will be created in this directory
            enable_plotting (boolean): If this is set to True, the recorder is no longer real time able and thus the recordings don't work properly. This is just to be used to tune the threshold for the end of utterance detection, not during deployment.
            threshold (int): The threshold below which the assumption of the end of utterance detection is silence
            voice_privacy (boolean): Whether or not to enable the masking of the users voice
            identifier (string): Passed on to Service; not used by the recorder itself.
        """
        Service.__init__(self, domain=domain, identifier=identifier)
        self.conversation_log_dir = conversation_log_dir
        self.recording_indicator = False
        self.audio_interface = pyaudio.PyAudio()
        self.push_to_talk_listener = keyboard.Listener(on_press=self.start_recording)
        self.threshold = threshold
        self.enable_plotting = enable_plotting
        self.voice_privacy = voice_privacy

    @PublishSubscribe(pub_topics=["speech_in"])
    def record_user_utterance(self):
        """
        Records audio once a button is pressed and stops if there is enough continuous silence.
        The numpy array consisting of the frames will be published once it's done.

        Recording stops after 3 seconds of continuous silence or after 20
        seconds in total, whichever comes first.

        Returns:
            dict(string, tuple(np.array, int)): The utterance in form of an array and the sampling rate of the utterance
        """
        chunk = 1024  # how many frames per chunk
        audio_format = pyaudio.paInt16  # 16 bit integer based audio for quick processing
        channels = 1  # our asr model only accepts mono sounds
        sampling_rate = 16000  # only 16000 Hz works for the asr model we're using

        # setup for naive end of utterance detection
        continuous_seconds_of_silence_before_utterance_ends = 3.0  # this may be changed freely
        reset = int((continuous_seconds_of_silence_before_utterance_ends * sampling_rate) / chunk)
        maximum_utterance_time_in_chunks = int((20 * sampling_rate) / chunk)  # 20 seconds

        binary_sequence = []  # this will hold the entire utterance once it's finished as binary data
        chunks_recorded = 0

        self.recording_indicator = True
        try:
            stream = self.audio_interface.open(format=audio_format,
                                               channels=channels,
                                               rate=sampling_rate,
                                               input=True,
                                               frames_per_buffer=chunk)
            try:
                if self.enable_plotting:
                    threshold_plotter = self.threshold_plotter_generator()
                silent_chunks_left = reset
                print("\nrecording...")
                for _ in range(maximum_utterance_time_in_chunks):
                    raw_data = stream.read(chunk)
                    chunks_recorded += 1
                    wave_data = wave.struct.unpack("%dh" % chunk, raw_data)
                    binary_sequence.append(raw_data)
                    if self.enable_plotting:
                        threshold_plotter(wave_data)
                    if np.max(wave_data) > self.threshold:
                        # speech detected: restart the silence countdown
                        silent_chunks_left = reset
                    else:
                        silent_chunks_left -= 1
                    if silent_chunks_left == 0:
                        break
                print("...done recording.\n")
            finally:
                # release the audio device even if reading or plotting failed
                stream.stop_stream()
                stream.close()
            if self.enable_plotting:
                plt.close()

            if self.conversation_log_dir is not None:
                # int() instead of np.math.floor(): the np.math alias no longer
                # exists in current NumPy releases.
                log_name = str(int(time.time())) + "_user.wav"
                with wave.open(os.path.join(self.conversation_log_dir, log_name), 'wb') as audio_file:
                    audio_file.setnchannels(channels)
                    audio_file.setsampwidth(self.audio_interface.get_sample_size(audio_format))
                    audio_file.setframerate(sampling_rate)
                    audio_file.writeframes(b''.join(binary_sequence))
        finally:
            # never leave the recorder locked after a failure, otherwise the
            # push-to-talk key would be ignored from then on
            self.recording_indicator = False

        audio_sequence = wave.struct.unpack("%dh" % (chunk * chunks_recorded), b''.join(binary_sequence))
        audio = np.array(audio_sequence, dtype=np.float32)
        if self.voice_privacy:
            audio = voice_sanitizer(audio)
        return {"speech_in": (audio, sampling_rate)}

    def start_recording(self, key):
        """
        This method is a callback of the push to talk key
        listener. It calls the recorder, if it's not already recording.

        Args:
            key (Key): The pressed key
        """
        if (key is keyboard.Key.cmd_r or key is keyboard.Key.ctrl_r) and not self.recording_indicator:
            self.record_user_utterance()

    def start_recorder(self):
        """
        Starts the listener and outputs that the speech recorder is ready for use
        """
        self.push_to_talk_listener.start()
        print("To speak to the system, tap your right [CTRL] or [CMD] key.\n"
              "It will try to automatically detect when your utterance is over.\n")

    def threshold_plotter_generator(self):
        """
        Generates a plotter to visualize when the signal is above the set threshold

        Returns:
            function: Plots the threshold with the current continuous waveform
        """
        import matplotlib
        matplotlib.use('TkAgg')
        plt.figure(figsize=(10, 2))
        plt.axhline(y=self.threshold, xmin=0.0, xmax=1.0, color='r')
        plt.axhline(y=-self.threshold, xmin=0.0, xmax=1.0, color='r')
        plt.pause(0.000000000001)

        def threshold_plotter(data):
            # redraw the current chunk together with the +/- threshold lines
            plt.clf()
            plt.tight_layout()
            plt.axis([0, len(data), -20000, 20000])
            plt.plot(data, color='b')
            plt.axhline(y=self.threshold, xmin=0.0, xmax=1.0, color='r')
            plt.axhline(y=-self.threshold, xmin=0.0, xmax=1.0, color='r')
            plt.pause(0.000000000001)

        return threshold_plotter
__init__(self, domain='', conversation_log_dir=None, enable_plotting=False, threshold=8000, voice_privacy=False, identifier=None)
special
¶
A service that can record a microphone upon a key pressing event and publish the result as an array. The end of the utterance is detected automatically, also the voice can be masked to alleviate privacy issues.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
domain |
Domain |
I don't know why this is here. Service needs it, but it means nothing in this context. |
'' |
conversation_log_dir |
string |
If this parameter is given, log files of the conversation will be created in this directory |
None |
enable_plotting |
boolean |
If this is set to True, the recorder is no longer real time able and thus the recordings don't work properly. This is just to be used to tune the threshold for the end of utterance detection, not during deployment. |
False |
threshold |
int |
The threshold below which the assumption of the end of utterance detection is silence |
8000 |
voice_privacy |
boolean |
Whether or not to enable the masking of the users voice |
False |
identifier |
string |
I don't know why this is here. Service needs it. |
None |
Source code in adviser/services/hci/speech/SpeechRecorder.py
def __init__(self, domain: Union[str, Domain] = "", conversation_log_dir: str = None, enable_plotting: bool = False, threshold: int = 8000,
             voice_privacy: bool = False, identifier: str = None) -> None:
    """
    A service that can record a microphone upon a key pressing event
    and publish the result as an array. The end of the utterance is
    detected automatically, also the voice can be masked to alleviate
    privacy issues.

    Args:
        domain (Domain): Passed on to Service; not used by the recorder itself.
        conversation_log_dir (string): If this parameter is given, log files of the conversation will be created in this directory
        enable_plotting (boolean): If this is set to True, the recorder is no longer real time able and thus the recordings don't work properly. This is just to be used to tune the threshold for the end of utterance detection, not during deployment.
        threshold (int): The threshold below which the assumption of the end of utterance detection is silence
        voice_privacy (boolean): Whether or not to enable the masking of the users voice
        identifier (string): Passed on to Service; not used by the recorder itself.
    """
    Service.__init__(self, domain=domain, identifier=identifier)

    # plain configuration
    self.conversation_log_dir = conversation_log_dir
    self.threshold = threshold
    self.enable_plotting = enable_plotting
    self.voice_privacy = voice_privacy

    # recording state and devices
    self.recording_indicator = False
    self.audio_interface = pyaudio.PyAudio()
    self.push_to_talk_listener = keyboard.Listener(on_press=self.start_recording)
record_user_utterance(self, *args, **kwargs)
¶
Source code in adviser/services/hci/speech/SpeechRecorder.py
def delegate(self, *args, **kwargs):
    """
    Call the wrapped function and publish the topics it returned.

    ``func`` and ``pub_topics`` are free variables taken from the enclosing
    scope (not visible in this listing).

    The wrapped function may return a dict whose keys are either ``"topic"``
    or ``"topic/domain"``. Keys are reduced to the bare topic name before the
    dict is handed back to the caller.

    Domain used for publishing a topic, in order of precedence:
        1. the entry in ``self._pub_topic_domains`` for that topic, if present
        2. ``self._domain_name``, if non-empty
        3. the domain embedded in the key returned by the wrapped function

    Returns:
        the (key-normalised) result of the wrapped function
    """
    func_inst = getattr(self, func.__name__)
    callargs = list(args)
    if self in callargs:  # remove self when in *args, because already known to function
        callargs.remove(self)

    result = func(self, *callargs, **kwargs)
    if result:
        # NOTE: only the first two "/"-separated segments of a key are used;
        # keys with more than one "/" lose everything after the second segment.
        domains = {}
        normalised = {}
        for key in result:
            segments = key.split("/")
            domains[segments[0]] = segments[1] if "/" in key else ""
            normalised[segments[0]] = result[key]
        result = normalised

    if func_inst not in self._publish_sockets:
        # not a publisher, just normal function
        return result

    socket = self._publish_sockets[func_inst]
    if socket and result:
        # publish messages
        for topic in pub_topics:
            if topic not in result:
                continue
            if topic in self._pub_topic_domains:
                override = self._pub_topic_domains[topic]
                topic_domain_str = f"{topic}/{override}" if override else topic
            else:
                # Resolve the domain per topic: previously the domain found for
                # the first topic was carried over to all following topics.
                topic_domain = self._domain_name if self._domain_name else domains[topic]
                topic_domain_str = f"{topic}/{topic_domain}" if topic_domain else topic
            _send_msg(socket, topic_domain_str, result[topic])
            if self.debug_logger:
                self.debug_logger.info(
                    f"- (DS): sent message from {func} to topic {topic_domain_str}:\n {result[topic]}")
    return result
start_recorder(self)
¶
Starts the listener and outputs that the speech recorder is ready for use
Source code in adviser/services/hci/speech/SpeechRecorder.py
start_recording(self, key)
¶
This method is a callback of the push to talk key listener. It calls the recorder, if it's not already recording.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key |
Key |
The pressed key |
required |
Source code in adviser/services/hci/speech/SpeechRecorder.py
def start_recording(self, key):
    """
    Push-to-talk callback for the keyboard listener: starts a recording
    when the right CMD or right CTRL key is pressed and no recording is
    currently running.

    Args:
        key (Key): The pressed key
    """
    is_push_to_talk_key = key is keyboard.Key.cmd_r or key is keyboard.Key.ctrl_r
    if not is_push_to_talk_key:
        return
    if self.recording_indicator:
        # a recording is already in progress
        return
    self.record_user_utterance()
threshold_plotter_generator(self)
¶
Generates a plotter to visualize when the signal is above the set threshold
Returns:
Type | Description |
---|---|
function |
Plots the threshold with the current continuous waveform |
Source code in adviser/services/hci/speech/SpeechRecorder.py
def threshold_plotter_generator(self):
    """
    Build a plotting callback that shows a waveform chunk against the
    configured silence threshold.

    Returns:
        function: Plots the threshold with the current continuous waveform
    """
    import matplotlib
    matplotlib.use('TkAgg')

    def draw_threshold_lines():
        # red horizontal lines at +threshold and -threshold
        for level in (self.threshold, -self.threshold):
            plt.axhline(y=level, xmin=0.0, xmax=1.0, color='r')

    plt.figure(figsize=(10, 2))
    draw_threshold_lines()
    plt.pause(0.000000000001)

    def threshold_plotter(data):
        plt.clf()
        plt.tight_layout()
        plt.axis([0, len(data), -20000, 20000])
        plt.plot(data, color='b')
        draw_threshold_lines()
        plt.pause(0.000000000001)

    return threshold_plotter
voice_sanitizer(audio)
¶
While this is by no means a good voice sanitizer, it works as a proof of concept. It randomly shifts the spectrogram of a speaker's utterance up or down, making automatic speaker identification much harder while keeping the impact on ASR performance as low as possible. The use should be turned off by default.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
audio |
np.array |
The audio represented as array |
required |
Returns:
Type | Description |
---|---|
np.array |
The mutated audio as array |
Source code in adviser/services/hci/speech/SpeechRecorder.py
def voice_sanitizer(audio):
    """
    While this is by no means a good voice sanitizer,
    it works as a proof of concept. It randomly shifts
    the spectrogram of a speaker's utterance up or down,
    making automatic speaker identification much harder
    while keeping impact on asr performance as low as
    possible. The use should be turned off by default.

    The shift is a random 3 to 5 frequency bins in a randomly chosen
    direction. Bins that have no source bin (the lowest ones when shifting
    up, the highest ones when shifting down) keep their original content.

    Args:
        audio (np.array): The audio represented as array

    Returns:
        np.array: The mutated audio as array
    """
    spectrogram = librosa.stft(audio)
    voice_shift = np.random.randint(3, 6)
    shifted = spectrogram.copy()
    if np.random.choice([True, False]):
        # mutate the voice to be higher: bin i takes the content of bin i - shift.
        # Done with slices because a negative row index would silently wrap
        # around to the top of the spectrogram instead of raising IndexError.
        shifted[voice_shift:] = spectrogram[:-voice_shift]
    else:
        # mutate the voice to be lower: bin i takes the content of bin i + shift
        shifted[:-voice_shift] = spectrogram[voice_shift:]
    return librosa.istft(shifted)
cleaners
¶
This file is derived from https://github.com/keithito/tacotron.
basic_cleaners(text)
¶
Basic pipeline that lowercases and collapses whitespace without transliteration.
collapse_whitespace(text)
¶
convert_to_ascii(text)
¶
custom_english_cleaners(text)
¶
Custom pipeline for English text, including number and abbreviation expansion.
Source code in adviser/services/hci/speech/cleaners.py
def custom_english_cleaners(text):
    """Custom pipeline for English text, including number and abbreviation expansion."""
    # the stages run strictly in this order; each one receives the output of the previous one
    pipeline = (
        convert_to_ascii,
        expand_email,
        expand_acronym,
        lowercase,
        expand_numbers,
        expand_abbreviations,
        expand_symbols,
        remove_unnecessary_symbols,
        uppercase,
        collapse_whitespace,
    )
    for stage in pipeline:
        text = stage(text)
    return text
english_cleaners(text)
¶
Pipeline for English text, including number and abbreviation expansion.
Source code in adviser/services/hci/speech/cleaners.py
expand_abbreviations(text)
¶
Preprocesses a text to turn abbreviations into forms that the TTS can pronounce properly
text (string): Text to be preprocessed
Source code in adviser/services/hci/speech/cleaners.py
expand_acronym(text)
¶
Preprocesses a text to turn acronyms into forms that the TTS can pronounce properly
text (string): Text to be preprocessed
expand_email(text)
¶
expand_numbers(text)
¶
expand_symbols(text)
¶
lowercase(text)
¶
normalize_numbers(text)
¶
Normalizes numbers in an utterance as preparation for TTS
text (string): Text to be preprocessed
Source code in adviser/services/hci/speech/cleaners.py
def normalize_numbers(text):
    """
    Normalizes numbers in an utterance as preparation for TTS

    Args:
        text (string): Text to be preprocessed

    Returns:
        the text after all number rewrites have been applied
    """
    # (pattern, replacement) pairs, applied strictly in this order
    rewrite_steps = (
        (_comma_number_re, _remove_commas),
        (_pounds_re, r'\1 pounds'),
        (_dollars_re, _expand_dollars),
        (_decimal_number_re, _expand_decimal_point),
        (_ordinal_re, _expand_ordinal),
        (_ID_number_re, _expand_ID_number),
        (_number_re, _expand_number),
    )
    for pattern, replacement in rewrite_steps:
        text = re.sub(pattern, replacement, text)
    return text
remove_unnecessary_symbols(text)
¶
transliteration_cleaners(text)
¶
Pipeline for non-English text that transliterates to ASCII.
uppercase(text)
¶
speech_utility
¶
Utility for the emotion recognition script that needs the utterance as a file
delete_file(filepath)
¶
Deletes the file at the given path to clean up the audio file once it's not needed anymore. This is why unique filenames are important.
filepath (string): path to the file that is to be deleted
Source code in adviser/services/hci/speech/speech_utility.py
def delete_file(filepath):
    """
    Deletes the file at the given path to clean up the audio file
    once it's not needed anymore. This is why unique filenames are
    important.

    A missing file is reported on stdout instead of raising.

    Args:
        filepath (string): path to the file that is to be deleted
    """
    try:
        # attempt the removal directly instead of checking for existence
        # first, so the file cannot vanish between the check and the removal
        os.remove(filepath)
    except FileNotFoundError:
        print("The file cannot be deleted, as it was not found. "
              "Please check the provided path for errors: \n{}".format(filepath))
sound_array_to_file(filepath, sampling_rate, sound_as_array)
¶
Saves the recording of the recorder to a file
Turns the audio from the recorder service into a wav file for processing with opensmile c++ scripts
filepath (string): full path, including filename and .wav suffix at an arbitrary location. Careful: python takes paths as relative to the main script. The name should be unique, to ensure files don't get mixed up if there are multiple calls in short time and one file might get overwritten or deleted before it's done being processed. sampling_rate (int): the sampling rate of the audio, as published by the recorder. sound_as_array (np.array): the audio in form of an array as published by the recorder.
Source code in adviser/services/hci/speech/speech_utility.py
def sound_array_to_file(filepath, sampling_rate, sound_as_array):
    """
    Saves the recording of the recorder to a file

    Turns the audio from the recorder service into a wav file for
    processing with opensmile c++ scripts

    Args:
        filepath (string): full path, including filename and .wav suffix
            at an arbitrary location. Careful: python takes paths as
            relative to the main script. The name should be unique, to
            ensure files don't get mixed up if there are multiple calls
            in short time and one file might get overwritten or deleted
            before it's done being processed.
        sampling_rate (int): the sampling rate of the audio, as
            published by the recorder
        sound_as_array (np.array): the audio in form of an array as
            published by the recorder
    """
    # NOTE(review): the `librosa.output` module was removed in librosa 0.8.0;
    # this call presumably requires an older pinned librosa - confirm against
    # the project's requirements before upgrading.
    librosa.output.write_wav(filepath, sound_as_array, sampling_rate)
video
special
¶
FeatureExtractor
¶
Feature extraction with openSMILE
VideoFeatureExtractor (Service)
¶
TODO
Source code in adviser/services/hci/video/FeatureExtractor.py
class VideoFeatureExtractor(Service):
    """Service that turns a list of video frames into aggregated facial-landmark features.

    For every second input frame the first detected face is located, its
    landmarks are predicted, and six distance-based features are computed.
    The per-frame features are then summarised with six statistics each.
    """

    def __init__(self, domain: Domain = ""):
        """Create the contrast-enhancement, face-detection and landmark-prediction helpers.

        Args:
            domain: passed through to `Service.__init__`
        """
        Service.__init__(self, domain=domain)
        self.module_dir = os.path.dirname(os.path.abspath(__file__))
        # # CLAHE (Contrast Limited Adaptive Histogram Equalization)
        self.CLAHE = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        # for detecting faces (returns coordinates of rectangle(s) of face area(s))
        self.DETECTOR = dlib.get_frontal_face_detector()
        # facial landmark predictor
        # the model file is resolved relative to this module, three directories up
        predictor_file = os.path.abspath(os.path.join(self.module_dir, '..', '..', '..', 'resources', 'models', 'video', 'shape_predictor_68_face_landmarks.dat'))
        self.PREDICTOR = dlib.shape_predictor(predictor_file)

    @PublishSubscribe(queued_sub_topics=["video_input"], sub_topics=["user_acts"],
                      pub_topics=["fl_features"])
    def extract_fl_features(self, video_input, user_acts):
        """Compute aggregated facial-landmark features for a sequence of frames.

        Args:
            video_input: sequence of frames; every second frame is processed.
                Frames are converted with COLOR_RGB2GRAY, so RGB input is assumed.
            user_acts: not read in this method
                (NOTE(review): presumably only used to trigger the call - confirm)

        Returns:
            dict: {'fl_features': array} where the array has shape (1, 36)
            (mean, min, max, std, 25th and 75th percentile of the six
            per-frame features), or {'fl_features': None} if no face was
            detected in any processed frame.
        """
        def _distance(a, b):
            # Euclidean distance between two landmark points
            return np.linalg.norm(a-b)

        print(f'VIDEO FEATURE ENTER, len(video_input): {len(video_input)}')

        features = []
        aggregated_feats = None
        # only every second frame is analysed
        for frame in video_input[::2]:
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
            frame = self.CLAHE.apply(frame)
            faces = self.DETECTOR(frame, 1)
            if len(faces) > 0:  # at least one face detected
                # only the first detected face is used
                landmarks = self.PREDICTOR(frame, faces[0])
                landmarks = face_utils.shape_to_np(landmarks)
                # reference distances used to normalise the eyebrow / lip sums below
                norm_left_eye = _distance(landmarks[21], landmarks[39])
                norm_right_eye = _distance(landmarks[22], landmarks[42])
                norm_lips = _distance(landmarks[33], landmarks[52])
                eyebrow_left = sum(
                    [(_distance(landmarks[39], landmarks[i]) / norm_left_eye)
                        for i in [18, 19, 20, 21]]
                )
                eyebrow_right = sum(
                    [(_distance(landmarks[42], landmarks[i]) / norm_right_eye)
                        for i in [22, 23, 24, 25]]
                )
                lip_left = sum(
                    [(_distance(landmarks[33], landmarks[i]) / norm_lips)
                        for i in [48, 49, 50]]
                )
                lip_right = sum(
                    [(_distance(landmarks[33], landmarks[i]) / norm_lips)
                        for i in [52, 53, 54]]
                )
                # mouth width / height are raw (not normalised) distances
                mouth_width = _distance(landmarks[48], landmarks[54])
                mouth_height = _distance(landmarks[51], landmarks[57])
                features.append(np.array([
                    eyebrow_left,
                    eyebrow_right,
                    lip_left,
                    lip_right,
                    mouth_width,
                    mouth_height
                ]))

        # aggregate features across frames
        if len(features) > 0:
            mean = np.mean(features, axis=0)
            mini = np.amin(features, axis=0)
            maxi = np.amax(features, axis=0)
            std = np.std(features, axis=0)
            perc25 = np.percentile(features, q=25, axis=0)
            perc75 = np.percentile(features, q=75, axis=0)
            # 6 statistics x 6 features -> one row of 36 values
            aggregated_feats = np.array([mean, mini, maxi, std, perc25, perc75]).reshape(1, 36)

        print("VIDEO FEAT PUB")
        return {'fl_features': aggregated_feats}
__init__(self, domain='')
special
¶
Source code in adviser/services/hci/video/FeatureExtractor.py
def __init__(self, domain: Domain = ""):
    """Create the contrast-enhancement, face-detection and landmark-prediction helpers.

    Args:
        domain: passed through to `Service.__init__`
    """
    Service.__init__(self, domain=domain)
    this_file = os.path.abspath(__file__)
    self.module_dir = os.path.dirname(this_file)

    # contrast limited adaptive histogram equalization for the gray frames
    self.CLAHE = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    # frontal face detector (yields the rectangle(s) of detected face area(s))
    self.DETECTOR = dlib.get_frontal_face_detector()

    # landmark model lives three directories above this module
    model_path_parts = ('..', '..', '..', 'resources', 'models', 'video',
                        'shape_predictor_68_face_landmarks.dat')
    predictor_file = os.path.abspath(os.path.join(self.module_dir, *model_path_parts))
    self.PREDICTOR = dlib.shape_predictor(predictor_file)
extract_fl_features(self, *args, **kwargs)
¶
Source code in adviser/services/hci/video/FeatureExtractor.py
def delegate(self, *args, **kwargs):
    # Wrapper around the decorated service method: calls it and, if the method
    # is registered as a publisher, sends each returned topic value.
    # `func`, `pub_topics` and `_send_msg` are not defined in this block; they
    # come from the enclosing scope (presumably the PublishSubscribe decorator).
    func_inst = getattr(self, func.__name__)

    callargs = list(args)
    if self in callargs:    # remove self when in *args, because already known to function
        callargs.remove(self)
    result = func(self, *callargs, **kwargs)
    if result:
        # result keys may have the form "topic/domain": remember the domain per
        # topic, then strip the domain part from the keys
        # fix! (user could have multiple "/" characters in topic - only use last one )
        domains = {res.split("/")[0]: res.split("/")[1] if "/" in res else "" for res in result}
        result = {key.split("/")[0]: result[key] for key in result}

    if func_inst not in self._publish_sockets:
        # not a publisher, just normal function
        return result

    socket = self._publish_sockets[func_inst]
    domain = self._domain_name
    if socket and result:
        # publish messages
        for topic in pub_topics:
        # for topic in result: # NOTE publish any returned value in dict with it's key as topic
            if topic in result:
                # `domain` is reassigned here, so a domain taken from one topic
                # is kept for the following topics of the same call
                domain = domain if domain else domains[topic]
                topic_domain_str = f"{topic}/{domain}" if domain else topic
                # an explicit per-topic publish domain overrides the above
                if topic in self._pub_topic_domains:
                    topic_domain_str = f"{topic}/{self._pub_topic_domains[topic]}" if self._pub_topic_domains[topic] else topic
                _send_msg(socket, topic_domain_str, result[topic])
                if self.debug_logger:
                    self.debug_logger.info(
                        f"- (DS): sent message from {func} to topic {topic_domain_str}:\n {result[topic]}")
    return result
VideoInput
¶
VideoInput (Service)
¶
Captures frames with a specified capture interval between two consecutive dialog turns and returns a list of frames.
Source code in adviser/services/hci/video/VideoInput.py
class VideoInput(Service):
    """
    Captures frames with a specified capture interval between two consecutive dialog turns and returns a list of frames.
    """

    def __init__(self, domain=None, camera_id: int = 0, capture_interval: int = 10e5, identifier: str = None):
        """
        Args:
            domain: passed through to `Service.__init__`
            camera_id (int): device id (if only 1 camera device is connected, id is 0, if two are connected choose between 0 and 1, ...)
            capture_interval (int): try to capture a frame every x microseconds - is a lower bound, no hard time guarantees (e.g. 5e5 -> every >= 0.5 seconds)
            identifier (str): passed through to `Service.__init__`
        """
        Service.__init__(self, domain, identifier=identifier)

        self.cap = cv2.VideoCapture(camera_id)  # get handle to camera device
        if not self.cap.isOpened():
            # VideoCapture.open() needs the device to open; calling it without
            # arguments raises instead of retrying
            self.cap.open(camera_id)

        self.terminating = Event()
        self.terminating.clear()
        self.capture_thread = Thread(target=self.capture)  # create thread object for capturing
        self.capture_interval = capture_interval

    def capture(self):
        """
        Continuous video capture, meant to be run in a loop.
        Calls `publish_img` once per interval tick to publish the captured image.
        Releases the camera handle when the loop ends.
        """
        while self.cap.isOpened() and not self.terminating.is_set():
            start_time = datetime.datetime.now()

            # Capture frame-by-frame
            # cap.read() returns a bool (true when frame was read correctly)
            ret, frame = self.cap.read()

            # Our operations on the frame come here
            if ret:
                rgb_img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                self.publish_img(rgb_img=rgb_img)

            # total_seconds() covers the whole elapsed time; `.microseconds`
            # is only the sub-second component and wraps after one second
            elapsed_seconds = (datetime.datetime.now() - start_time).total_seconds()
            # time to wait for next capture to match specified sampling rate in seconds
            wait_seconds = self.capture_interval * 1e-6 - elapsed_seconds
            if wait_seconds > 0.0:
                time.sleep(wait_seconds)

        if self.cap.isOpened():
            self.cap.release()

    def dialog_end(self):
        """Signal the capture loop to stop."""
        self.terminating.set()

    def dialog_start(self):
        """Start the capture thread if it is not running."""
        # NOTE(review): a Thread object can only be started once and
        # `terminating` is never cleared again, so a second dialog after
        # `dialog_end` would raise RuntimeError here - confirm whether
        # multi-dialog use of one instance is expected.
        if not self.capture_thread.is_alive():
            print("Starting video capture...")
            self.capture_thread.start()

    @PublishSubscribe(pub_topics=['video_input'])
    def publish_img(self, rgb_img) -> dict(video_input=List[object]):
        """
        Helper function to publish images from a loop.
        """
        return {'video_input': rgb_img}  # NOTE: in the future, copy frames for more safety (capturing thread may overwrite them)
__init__(self, domain=None, camera_id=0, capture_interval=1000000.0, identifier=None)
special
¶
Parameters:
Name | Type | Description | Default |
---|---|---|---|
camera_id |
int |
device id (if only 1 camera device is connected, id is 0, if two are connected choose between 0 and 1, ...) |
0 |
capture_interval |
int |
try to capture a frame every x microseconds - is a lower bound, no hard time guarantees (e.g. 5e5 -> every >= 0.5 seconds) |
1000000.0 |
Source code in adviser/services/hci/video/VideoInput.py
def __init__(self, domain=None, camera_id: int = 0, capture_interval: int = 10e5, identifier: str = None):
    """
    Args:
        domain: passed through to `Service.__init__`
        camera_id (int): device id (if only 1 camera device is connected, id is 0, if two are connected choose between 0 and 1, ...)
        capture_interval (int): try to capture a frame every x microseconds - is a lower bound, no hard time guarantees (e.g. 5e5 -> every >= 0.5 seconds)
        identifier (str): passed through to `Service.__init__`
    """
    Service.__init__(self, domain, identifier=identifier)

    self.cap = cv2.VideoCapture(camera_id)  # get handle to camera device
    if not self.cap.isOpened():
        # VideoCapture.open() needs the device to open; calling it without
        # arguments raises instead of retrying
        self.cap.open(camera_id)

    self.terminating = Event()
    self.terminating.clear()
    self.capture_thread = Thread(target=self.capture)  # create thread object for capturing
    self.capture_interval = capture_interval
capture(self)
¶
Continuous video capture, meant to be run in a loop.
Calls publish_img
once per interval tick to publish the captured image.
Source code in adviser/services/hci/video/VideoInput.py
def capture(self):
    """
    Continuous video capture, meant to be run in a loop.
    Calls `publish_img` once per interval tick to publish the captured image.
    Releases the camera handle when the loop ends.
    """
    while self.cap.isOpened() and not self.terminating.is_set():
        start_time = datetime.datetime.now()

        # Capture frame-by-frame
        # cap.read() returns a bool (true when frame was read correctly)
        ret, frame = self.cap.read()

        # Our operations on the frame come here
        if ret:
            rgb_img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            self.publish_img(rgb_img=rgb_img)

        # total_seconds() covers the whole elapsed time; `.microseconds`
        # is only the sub-second component and wraps after one second
        elapsed_seconds = (datetime.datetime.now() - start_time).total_seconds()
        # time to wait for next capture to match specified sampling rate in seconds
        wait_seconds = self.capture_interval * 1e-6 - elapsed_seconds
        if wait_seconds > 0.0:
            time.sleep(wait_seconds)

    if self.cap.isOpened():
        self.cap.release()
dialog_end(self)
¶
dialog_start(self)
¶
This function is called before the first message to a new dialog is published. You should overwrite this function to set/reset dialog-level variables.
publish_img(self, *args, **kwargs)
¶
Source code in adviser/services/hci/video/VideoInput.py
def delegate(self, *args, **kwargs):
    # Wrapper around the decorated service method: calls it and, if the method
    # is registered as a publisher, sends each returned topic value.
    # `func`, `pub_topics` and `_send_msg` are not defined in this block; they
    # come from the enclosing scope (presumably the PublishSubscribe decorator).
    func_inst = getattr(self, func.__name__)

    callargs = list(args)
    if self in callargs:    # remove self when in *args, because already known to function
        callargs.remove(self)
    result = func(self, *callargs, **kwargs)
    if result:
        # result keys may have the form "topic/domain": remember the domain per
        # topic, then strip the domain part from the keys
        # fix! (user could have multiple "/" characters in topic - only use last one )
        domains = {res.split("/")[0]: res.split("/")[1] if "/" in res else "" for res in result}
        result = {key.split("/")[0]: result[key] for key in result}

    if func_inst not in self._publish_sockets:
        # not a publisher, just normal function
        return result

    socket = self._publish_sockets[func_inst]
    domain = self._domain_name
    if socket and result:
        # publish messages
        for topic in pub_topics:
        # for topic in result: # NOTE publish any returned value in dict with it's key as topic
            if topic in result:
                # `domain` is reassigned here, so a domain taken from one topic
                # is kept for the following topics of the same call
                domain = domain if domain else domains[topic]
                topic_domain_str = f"{topic}/{domain}" if domain else topic
                # an explicit per-topic publish domain overrides the above
                if topic in self._pub_topic_domains:
                    topic_domain_str = f"{topic}/{self._pub_topic_domains[topic]}" if self._pub_topic_domains[topic] else topic
                _send_msg(socket, topic_domain_str, result[topic])
                if self.debug_logger:
                    self.debug_logger.info(
                        f"- (DS): sent message from {func} to topic {topic_domain_str}:\n {result[topic]}")
    return result
nlg
special
¶
__all__
special
¶
affective_nlg
¶
Handcrafted (i.e. template-based) Natural Language Generation Module
HandcraftedEmotionNLG (HandcraftedNLG)
¶
A child of the HandcraftedNLG, the HandcraftedEmotionNLG can choose between multiple affective response templates for each sys_act depending on the current sys_emotion
Source code in adviser/services/nlg/affective_nlg.py
class HandcraftedEmotionNLG(HandcraftedNLG):
    """
    A child of the HandcraftedNLG, the HandcraftedEmotionNLG can choose between multiple affective
    response templates for each sys_act depending on the current sys_emotion
    """

    def __init__(self, domain: Domain, sub_topic_domains={}, template_file: str = None,
                 logger: DiasysLogger = DiasysLogger(), template_file_german: str = None,
                 emotions: List[str] = [], debug_logger = None):
        """Constructor: loads one template file per emotion plus the neutral one.

        Calls `Service.__init__` directly rather than `HandcraftedNLG.__init__`;
        the templates are loaded by `_initialise_templates` instead.

        Args:
            domain (Domain): the domain; its name selects the template files
            sub_topic_domains: passed through to `Service.__init__`
            template_file (str): stored as `template_filename`; not used for loading here
            logger (DiasysLogger): logger used for error reporting
            template_file_german (str): not used by this class
            emotions (List[str]): emotion names for which a template file is loaded
            debug_logger: passed through to `Service.__init__`
        """
        Service.__init__(self, domain=domain, sub_topic_domains=sub_topic_domains, debug_logger=debug_logger)
        self.domain = domain
        self.template_filename = template_file
        self.templates = {}
        self.logger = logger
        self.emotions = emotions
        self._initialise_templates()

    @PublishSubscribe(sub_topics=["sys_act", "sys_emotion", "sys_engagement"], pub_topics=["sys_utterance"])
    def generate_system_utterance(self, sys_act: SysAct = None, sys_emotion: str = None,
                                  sys_engagement: str = None) -> dict(sys_utterance=str):
        """
        Takes a system act, system emotion choice, and system engagement level choice, then
        searches for a fitting rule, applies it and returns the message.

        Args:
            sys_act (SysAct): The system act, to check whether the dialogue was finished
            sys_emotion (str): A string representing the system's choice of emotional response;
                matched case-insensitively, a missing value selects the neutral templates
            sys_engagement (str): A string representing how engaged the system thinks the user is
                (not read in this method)

        Returns:
            dict: a dict containing the system utterance

        Raises:
            KeyError: if no template file was loaded for `sys_emotion`
            BaseException: any error raised while creating the message is logged and re-raised
        """
        # templates are stored under lower-cased emotion names and "neutral" is
        # always loaded, so normalise the lookup key the same way
        emotion_key = sys_emotion.lower() if sys_emotion else "neutral"
        try:
            message = self.templates[emotion_key].create_message(sys_act)
        except BaseException as error:
            # BaseException is kept because template functions may raise it
            # directly; the error is only logged here, never swallowed
            self.logger.error(error)
            raise

        # self.logger.dialog_turn("System Action: " + message)
        return {'sys_utterance': message}

    def _initialise_templates(self):
        """
        Loads one template file per configured emotion (stored under the
        lower-cased emotion name) plus the neutral template file, then
        registers the `_template_` helper methods with each of them.
        """
        for emotion in self.emotions:
            self.templates[emotion.lower()] = TemplateFile(os.path.join(
                os.path.dirname(os.path.abspath(__file__)),
                f'../../resources/nlg_templates/{self.domain.get_domain_name()}Messages{emotion}.nlg'),
                self.domain)
        self.templates["neutral"] = TemplateFile(os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            f'../../resources/nlg_templates/{self.domain.get_domain_name()}Messages.nlg'),
            self.domain)

        self._add_additional_methods_for_template_file()

    def _add_additional_methods_for_template_file(self):
        """add the function prefixed by "_template_" to the template file interpreter"""
        for (method_name, method) in inspect.getmembers(type(self), inspect.isfunction):
            if method_name.startswith('_template_'):
                for emotion in self.templates:
                    # method_name[10:] strips the "_template_" prefix
                    self.templates[emotion].add_python_function(method_name[10:], method, [self])
__init__(self, domain, sub_topic_domains={}, template_file=None, logger=<DiasysLogger adviser (NOTSET)>, template_file_german=None, emotions=[], debug_logger=None)
special
¶
Source code in adviser/services/nlg/affective_nlg.py
def __init__(self, domain: Domain, sub_topic_domains={}, template_file: str = None,
             logger: DiasysLogger = DiasysLogger(), template_file_german: str = None,
             emotions: List[str] = [], debug_logger = None):
    """Set up the service, store the configuration and load the per-emotion templates."""
    Service.__init__(
        self,
        domain=domain,
        sub_topic_domains=sub_topic_domains,
        debug_logger=debug_logger,
    )

    # configuration handed in by the caller
    self.domain = domain
    self.logger = logger
    self.emotions = emotions
    self.template_filename = template_file

    # filled by _initialise_templates: one entry per emotion plus "neutral"
    self.templates = {}
    self._initialise_templates()
generate_system_utterance(self, *args, **kwargs)
¶
Source code in adviser/services/nlg/affective_nlg.py
def delegate(self, *args, **kwargs):
    # Wrapper around the decorated service method: calls it and, if the method
    # is registered as a publisher, sends each returned topic value.
    # `func`, `pub_topics` and `_send_msg` are not defined in this block; they
    # come from the enclosing scope (presumably the PublishSubscribe decorator).
    func_inst = getattr(self, func.__name__)

    callargs = list(args)
    if self in callargs:    # remove self when in *args, because already known to function
        callargs.remove(self)
    result = func(self, *callargs, **kwargs)
    if result:
        # result keys may have the form "topic/domain": remember the domain per
        # topic, then strip the domain part from the keys
        # fix! (user could have multiple "/" characters in topic - only use last one )
        domains = {res.split("/")[0]: res.split("/")[1] if "/" in res else "" for res in result}
        result = {key.split("/")[0]: result[key] for key in result}

    if func_inst not in self._publish_sockets:
        # not a publisher, just normal function
        return result

    socket = self._publish_sockets[func_inst]
    domain = self._domain_name
    if socket and result:
        # publish messages
        for topic in pub_topics:
        # for topic in result: # NOTE publish any returned value in dict with it's key as topic
            if topic in result:
                # `domain` is reassigned here, so a domain taken from one topic
                # is kept for the following topics of the same call
                domain = domain if domain else domains[topic]
                topic_domain_str = f"{topic}/{domain}" if domain else topic
                # an explicit per-topic publish domain overrides the above
                if topic in self._pub_topic_domains:
                    topic_domain_str = f"{topic}/{self._pub_topic_domains[topic]}" if self._pub_topic_domains[topic] else topic
                _send_msg(socket, topic_domain_str, result[topic])
                if self.debug_logger:
                    self.debug_logger.info(
                        f"- (DS): sent message from {func} to topic {topic_domain_str}:\n {result[topic]}")
    return result
bc_nlg
¶
Handcrafted (i.e. template-based) Natural Language Generation Module with backchannel
BackchannelHandcraftedNLG (HandcraftedNLG)
¶
Handcrafted (i.e. template-based) Natural Language Generation Module
A rule-based approach on natural language generation. The rules have to be specified within a template file using the ADVISER NLG syntax. Python methods that are called within a template file must be specified in the HandcraftedNLG class by using the prefix "_template_". For example, the method "_template_genitive_s" can be accessed in the template file via calling {genitive_s(name)}
Attributes:
Name | Type | Description |
---|---|---|
domain |
Domain |
the domain |
template_filename |
str |
the NLG template filename |
templates |
TemplateFile |
the parsed and ready-to-go NLG template file |
template_english |
str |
the name of the English NLG template file |
template_german |
str |
the name of the German NLG template file |
language |
Language |
the language of the dialogue |
Source code in adviser/services/nlg/bc_nlg.py
class BackchannelHandcraftedNLG(HandcraftedNLG):
    """Handcrafted (i.e. template-based) Natural Language Generation Module

    A rule-based approach on natural language generation.
    The rules have to be specified within a template file using the ADVISER NLG syntax.
    Python methods that are called within a template file must be specified in the
    HandcraftedNLG class by using the prefix "_template_". For example, the method
    "_template_genitive_s" can be accessed in the template file via calling {genitive_s(name)}

    In addition to the parent class, a backchannel phrase chosen from the
    predicted backchannel class is prepended to the generated utterance.

    Attributes:
        domain (Domain): the domain
        template_filename (str): the NLG template filename
        templates (TemplateFile): the parsed and ready-to-go NLG template file
        template_english (str): the name of the English NLG template file
        template_german (str): the name of the German NLG template file
        language (Language): the language of the dialogue
        backchannels (dict): candidate backchannel phrases per predicted class
    """

    def __init__(self, domain: Domain, sub_topic_domains: Dict[str, str] = {}, template_file: str = None,
                 logger: DiasysLogger = DiasysLogger(), template_file_german: str = None,
                 language: Language = None):
        """Constructor mainly extracts methods and rules from the template file"""
        # forward the caller's arguments; hard-coded None / a fresh logger here
        # would silently discard whatever was passed in
        HandcraftedNLG.__init__(
            self, domain, template_file=template_file,
            logger=logger, template_file_german=template_file_german,
            language=language, sub_topic_domains=sub_topic_domains)

        # class_int_mapping = {0: b'no_bc', 1: b'assessment', 2: b'continuer'}
        self.backchannels = {
            0: [''],
            1: ['Okay. ', 'Yeah. '],
            2: ['Um-hum. ', 'Uh-huh. ']
        }

    @PublishSubscribe(sub_topics=["sys_act", 'predicted_BC'], pub_topics=["sys_utterance"])
    def publish_system_utterance(self, sys_act: SysAct = None, predicted_BC: int = None) -> dict(sys_utterance=str):
        """
        Takes a system act, searches for a fitting rule, adds, backchannel and applies it
        and returns the message.

        mapping = {0: b'no_bc', 1: b'assessment', 2: b'continuer'}

        Args:
            sys_act (SysAct): The system act, to check whether the dialogue was finished
            predicted_BC (int): integer representation of the BC; a missing or
                unknown value adds no backchannel

        Returns:
            dict: a dict containing the system utterance
        """
        message = self.generate_system_utterance(sys_act)

        # utterances containing 'Sorry' are sent without a backchannel
        if 'Sorry' not in message:
            # only the first phrase of each class is used; the fallback keeps a
            # missing prediction (the default None) from raising KeyError
            backchannel = self.backchannels.get(predicted_BC, [''])[0]
            message = backchannel + message

        return {'sys_utterance': message}
__init__(self, domain, sub_topic_domains={}, template_file=None, logger=<DiasysLogger adviser (NOTSET)>, template_file_german=None, language=None)
special
¶
Source code in adviser/services/nlg/bc_nlg.py
def __init__(self, domain: Domain, sub_topic_domains: Dict[str, str] = {}, template_file: str = None,
             logger: DiasysLogger = DiasysLogger(), template_file_german: str = None,
             language: Language = None):
    """Constructor mainly extracts methods and rules from the template file"""
    # forward the caller's arguments; hard-coded None / a fresh logger here
    # would silently discard whatever was passed in
    HandcraftedNLG.__init__(
        self, domain, template_file=template_file,
        logger=logger, template_file_german=template_file_german,
        language=language, sub_topic_domains=sub_topic_domains)

    # class_int_mapping = {0: b'no_bc', 1: b'assessment', 2: b'continuer'}
    self.backchannels = {
        0: [''],
        1: ['Okay. ', 'Yeah. '],
        2: ['Um-hum. ', 'Uh-huh. ']
    }
publish_system_utterance(self, *args, **kwargs)
¶
Source code in adviser/services/nlg/bc_nlg.py
def delegate(self, *args, **kwargs):
    # Wrapper around the decorated service method: calls it and, if the method
    # is registered as a publisher, sends each returned topic value.
    # `func`, `pub_topics` and `_send_msg` are not defined in this block; they
    # come from the enclosing scope (presumably the PublishSubscribe decorator).
    func_inst = getattr(self, func.__name__)

    callargs = list(args)
    if self in callargs:    # remove self when in *args, because already known to function
        callargs.remove(self)
    result = func(self, *callargs, **kwargs)
    if result:
        # result keys may have the form "topic/domain": remember the domain per
        # topic, then strip the domain part from the keys
        # fix! (user could have multiple "/" characters in topic - only use last one )
        domains = {res.split("/")[0]: res.split("/")[1] if "/" in res else "" for res in result}
        result = {key.split("/")[0]: result[key] for key in result}

    if func_inst not in self._publish_sockets:
        # not a publisher, just normal function
        return result

    socket = self._publish_sockets[func_inst]
    domain = self._domain_name
    if socket and result:
        # publish messages
        for topic in pub_topics:
        # for topic in result: # NOTE publish any returned value in dict with it's key as topic
            if topic in result:
                # `domain` is reassigned here, so a domain taken from one topic
                # is kept for the following topics of the same call
                domain = domain if domain else domains[topic]
                topic_domain_str = f"{topic}/{domain}" if domain else topic
                # an explicit per-topic publish domain overrides the above
                if topic in self._pub_topic_domains:
                    topic_domain_str = f"{topic}/{self._pub_topic_domains[topic]}" if self._pub_topic_domains[topic] else topic
                _send_msg(socket, topic_domain_str, result[topic])
                if self.debug_logger:
                    self.debug_logger.info(
                        f"- (DS): sent message from {func} to topic {topic_domain_str}:\n {result[topic]}")
    return result
nlg
¶
Handcrafted (i.e. template-based) Natural Language Generation Module
HandcraftedNLG (Service)
¶
Handcrafted (i.e. template-based) Natural Language Generation Module
A rule-based approach on natural language generation. The rules have to be specified within a template file using the ADVISER NLG syntax. Python methods that are called within a template file must be specified in the HandcraftedNLG class by using the prefix "_template_". For example, the method "_template_genitive_s" can be accessed in the template file via calling {genitive_s(name)}
Attributes:
Name | Type | Description |
---|---|---|
domain |
Domain |
the domain |
template_filename |
str |
the NLG template filename |
templates |
TemplateFile |
the parsed and ready-to-go NLG template file |
template_english |
str |
the name of the English NLG template file |
template_german |
str |
the name of the German NLG template file |
language |
Language |
the language of the dialogue |
Source code in adviser/services/nlg/nlg.py
class HandcraftedNLG(Service):
    """Handcrafted (i.e. template-based) Natural Language Generation Module

    A rule-based approach on natural language generation.
    The rules have to be specified within a template file using the ADVISER NLG syntax.
    Python methods that are called within a template file must be specified in the
    HandcraftedNLG class by using the prefix "_template_". For example, the method
    "_template_genitive_s" can be accessed in the template file via calling {genitive_s(name)}

    Attributes:
        domain (Domain): the domain
        template_filename (str): the NLG template filename
        templates (TemplateFile): the parsed and ready-to-go NLG template file
        template_english (str): the name of the English NLG template file
        template_german (str): the name of the German NLG template file
        language (Language): the language of the dialogue
    """

    def __init__(self, domain: Domain, template_file: str = None, sub_topic_domains: Dict[str, str] = {},
                 logger: DiasysLogger = DiasysLogger(), template_file_german: str = None,
                 language: Language = None):
        """Constructor mainly extracts methods and rules from the template file

        Args:
            domain (Domain): the domain; its name selects the default template file
            template_file (str): English template file; the default location is used if None
            sub_topic_domains (Dict[str, str]): passed through to `Service.__init__`
            logger (DiasysLogger): logger used for error reporting
            template_file_german (str): German template file; the default location is used if None
            language (Language): language of the dialogue; English if not given
        """
        Service.__init__(self, domain=domain, sub_topic_domains=sub_topic_domains)

        # honour the requested language; it must not be reset to English
        # afterwards, otherwise the German templates can never be selected
        self.language = language if language else Language.ENGLISH
        self.template_english = template_file
        # TODO: at some point if we expand languages, maybe make kwargs? --LV
        self.template_german = template_file_german
        self.domain = domain
        self.template_filename = None
        self.templates = None
        self.logger = logger

        self._initialise_language(self.language)

    @PublishSubscribe(sub_topics=["sys_act"], pub_topics=["sys_utterance"])
    def publish_system_utterance(self, sys_act: SysAct = None) -> dict(sys_utterance=str):
        """Generates the system utterance and publishes it.

        Args:
            sys_act (SysAct): The system act published by the policy

        Returns:
            dict: a dict containing the system utterance
        """
        return {'sys_utterance': self.generate_system_utterance(sys_act)}

    def generate_system_utterance(self, sys_act: SysAct = None) -> str:
        """Main function of the NLG module

        Takes a system act, searches for a fitting rule, applies it and returns the message.
        Overwrite this function if you inherit from the NLG module.

        Args:
            sys_act (SysAct): The system act

        Returns:
            The utterance generated by applying a fitting template

        Raises:
            BaseException: any error raised while creating the message is logged and re-raised
        """
        try:
            message = self.templates.create_message(sys_act)
        except BaseException as error:
            # BaseException is kept because template functions may raise it
            # directly; the error is only logged here, never swallowed
            self.logger.error(error)
            raise

        # self.logger.dialog_turn("System Action: " + message)
        return message

    def _initialise_language(self, language: Language):
        """
        Loads the correct template file based on which language has been selected
        this should only be called on the first turn of the dialog

        Args:
            language (Language): Enum representing the language the user has selected
        """
        if language == Language.ENGLISH:
            if self.template_english is None:
                self.template_filename = os.path.join(
                    os.path.dirname(os.path.abspath(__file__)),
                    '../../resources/nlg_templates/%sMessages.nlg' % self.domain.get_domain_name())
            else:
                self.template_filename = self.template_english
        if language == Language.GERMAN:
            if self.template_german is None:
                self.template_filename = os.path.join(
                    os.path.dirname(os.path.abspath(__file__)),
                    '../../resources/nlg_templates/{}MessagesGerman.nlg'.format(
                        self.domain.get_domain_name()))
            else:
                self.template_filename = self.template_german

        self.templates = TemplateFile(self.template_filename, self.domain)
        self._add_additional_methods_for_template_file()

    def _add_additional_methods_for_template_file(self):
        """add the function prefixed by "_template_" to the template file interpreter"""
        for (method_name, method) in inspect.getmembers(type(self), inspect.isfunction):
            if method_name.startswith('_template_'):
                # method_name[10:] strips the "_template_" prefix
                self.templates.add_python_function(method_name[10:], method, [self])

    def _template_genitive_s(self, name: str) -> str:
        """Return the English genitive form of `name` (apostrophe only after a final 's')."""
        # endswith() also copes with an empty name, where name[-1] would raise
        if name.endswith('s'):
            return f"{name}'"
        else:
            return f"{name}'s"

    def _template_genitive_s_german(self, name: str) -> str:
        """Return the German genitive form of `name` (apostrophe after s, x, ß or z)."""
        # endswith() also copes with an empty name, where name[-1] would raise
        if name.endswith(('s', 'x', 'ß', 'z')):
            return f"{name}'"
        else:
            return f"{name}s"
__init__(self, domain, template_file=None, sub_topic_domains={}, logger=<DiasysLogger adviser (NOTSET)>, template_file_german=None, language=None)
special
¶
Constructor mainly extracts methods and rules from the template file
Source code in adviser/services/nlg/nlg.py
def __init__(self, domain: Domain, template_file: str = None, sub_topic_domains: Dict[str, str] = {},
             logger: DiasysLogger = DiasysLogger(), template_file_german: str = None,
             language: Language = None):
    """Constructor mainly extracts methods and rules from the template file

    Args:
        domain (Domain): the domain; its name selects the default template file
        template_file (str): English template file; the default location is used if None
        sub_topic_domains (Dict[str, str]): passed through to `Service.__init__`
        logger (DiasysLogger): logger used for error reporting
        template_file_german (str): German template file; the default location is used if None
        language (Language): language of the dialogue; English if not given
    """
    Service.__init__(self, domain=domain, sub_topic_domains=sub_topic_domains)

    # honour the requested language; it must not be reset to English
    # afterwards, otherwise the German templates can never be selected
    self.language = language if language else Language.ENGLISH
    self.template_english = template_file
    # TODO: at some point if we expand languages, maybe make kwargs? --LV
    self.template_german = template_file_german
    self.domain = domain
    self.template_filename = None
    self.templates = None
    self.logger = logger

    self._initialise_language(self.language)
generate_system_utterance(self, sys_act=None)
¶
Main function of the NLG module
Takes a system act, searches for a fitting rule, applies it and returns the message. Overwrite this function if you inherit from the NLG module.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
sys_act |
SysAct |
The system act |
None |
Returns:
Type | Description |
---|---|
str |
The utterance generated by applying a fitting template |
Source code in adviser/services/nlg/nlg.py
def generate_system_utterance(self, sys_act: SysAct = None) -> str:
    """Main function of the NLG module

    Takes a system act, searches for a fitting rule, applies it and returns the message.
    Overwrite this function if you inherit from the NLG module.

    Args:
        sys_act (SysAct): The system act

    Returns:
        The utterance generated by applying a fitting template

    Raises:
        BaseException: any error raised while creating the message is logged and re-raised
    """
    try:
        message = self.templates.create_message(sys_act)
    except BaseException as error:
        # BaseException is kept because template functions may raise it
        # directly; the error is only logged here, never swallowed.
        # The former "no fitting rule" branch after this handler could never
        # run because the error is always re-raised.
        self.logger.error(error)
        raise

    # self.logger.dialog_turn("System Action: " + message)
    return message
publish_system_utterance(self, *args, **kwargs)
¶
Source code in adviser/services/nlg/nlg.py
def delegate(self, *args, **kwargs):
    # Wrapper around the decorated service method: calls it and, if the method
    # is registered as a publisher, sends each returned topic value.
    # `func`, `pub_topics` and `_send_msg` are not defined in this block; they
    # come from the enclosing scope (presumably the PublishSubscribe decorator).
    func_inst = getattr(self, func.__name__)

    callargs = list(args)
    if self in callargs:    # remove self when in *args, because already known to function
        callargs.remove(self)
    result = func(self, *callargs, **kwargs)
    if result:
        # result keys may have the form "topic/domain": remember the domain per
        # topic, then strip the domain part from the keys
        # fix! (user could have multiple "/" characters in topic - only use last one )
        domains = {res.split("/")[0]: res.split("/")[1] if "/" in res else "" for res in result}
        result = {key.split("/")[0]: result[key] for key in result}

    if func_inst not in self._publish_sockets:
        # not a publisher, just normal function
        return result

    socket = self._publish_sockets[func_inst]
    domain = self._domain_name
    if socket and result:
        # publish messages
        for topic in pub_topics:
        # for topic in result: # NOTE publish any returned value in dict with it's key as topic
            if topic in result:
                # `domain` is reassigned here, so a domain taken from one topic
                # is kept for the following topics of the same call
                domain = domain if domain else domains[topic]
                topic_domain_str = f"{topic}/{domain}" if domain else topic
                # an explicit per-topic publish domain overrides the above
                if topic in self._pub_topic_domains:
                    topic_domain_str = f"{topic}/{self._pub_topic_domains[topic]}" if self._pub_topic_domains[topic] else topic
                _send_msg(socket, topic_domain_str, result[topic])
                if self.debug_logger:
                    self.debug_logger.info(
                        f"- (DS): sent message from {func} to topic {topic_domain_str}:\n {result[topic]}")
    return result
templates
special
¶
builtinfunctions
¶
ForEntryFunction (Function)
¶
Source code in adviser/services/nlg/templates/builtinfunctions.py
class ForEntryFunction(Function):
    """Built-in template function ``for_entry``.

    Applies another template function to every (slot, value) pair of its first
    argument and joins the produced texts with the two given separators.
    """

    def __init__(self, global_memory):
        Function.__init__(self, 'for_entry(slots, function, separator_first, separator_last)')
        self.global_memory = global_memory

    def is_applicable(self, parameters: Memory) -> bool:
        """Return True when at least the four mandatory arguments were passed."""
        argument_count = len(parameters.variables)
        return argument_count >= 4

    def apply(self, parameters: Memory = None) -> str:
        """Apply the named function to each (slot, value) pair and join the results.

        Positional variables: 0 = iterable of pairs, 1 = function name,
        2 = separator, 3 = separator before the last element, 4.. = extra arguments
        handed to the called function as ``arg0``, ``arg1``, ...

        Raises:
            BaseException: if the called function is not applicable to a pair
        """
        variables = parameters.variables
        function = parameters.get_function(variables[1].value)
        extra_arguments = [extra.value for extra in variables[4:]]
        texts: List[str] = []
        for entry in variables[0].value:
            memory = self._build_memory(entry[0], entry[1], extra_arguments)
            if not function.is_applicable(memory):
                raise BaseException(f'The function {function.function_name} could not be called '
                                    f'from the for_entry function')
            texts.append(function.apply(memory))
        separator = variables[2].value
        last_separator = variables[3].value
        return self._create_text_from_elements(texts, separator, last_separator)

    def _build_memory(self, slot: str, value: str, arguments: List[str]):
        """Create the memory for one call: ``slot``, ``value`` and ``arg<i>`` variables."""
        memory = Memory(self.global_memory)
        memory.add_variable(Variable('slot', slot))
        memory.add_variable(Variable('value', value))
        for position, argument in enumerate(arguments):
            memory.add_variable(Variable(f'arg{position}', argument))
        return memory

    def _create_text_from_elements(self, elements: List[str], separator: str, last_separator: str):
        """Join ``elements`` with ``separator``, using ``last_separator`` before the final one."""
        if not elements:
            return ''
        *leading, final = elements
        if not leading:
            return final
        return separator.join(leading) + last_separator + final
__init__(self, global_memory)
special
¶
apply(self, parameters=None)
¶
Source code in adviser/services/nlg/templates/builtinfunctions.py
def apply(self, parameters: Memory = None) -> str:
    """Apply the named function to each (slot, value) pair and join the results.

    Positional variables: 0 = iterable of pairs, 1 = function name,
    2 = separator, 3 = separator before the last element, 4.. = extra arguments.

    Raises:
        BaseException: if the called function is not applicable to a pair
    """
    variables = parameters.variables
    function = parameters.get_function(variables[1].value)
    extra_arguments = [extra.value for extra in variables[4:]]
    texts: List[str] = []
    for entry in variables[0].value:
        memory = self._build_memory(entry[0], entry[1], extra_arguments)
        if not function.is_applicable(memory):
            raise BaseException(f'The function {function.function_name} could not be called '
                                f'from the for_entry function')
        texts.append(function.apply(memory))
    separator = variables[2].value
    last_separator = variables[3].value
    return self._create_text_from_elements(texts, separator, last_separator)
is_applicable(self, parameters)
¶
ForEntryListFunction (Function)
¶
Source code in adviser/services/nlg/templates/builtinfunctions.py
class ForEntryListFunction(Function):
    """Built-in template function ``for_entry_list``.

    Works on (slot, values) pairs: the called function is applied to every value
    of a slot, the texts of one slot are joined with the value separators, and
    the per-slot texts are then joined with the slot separators.
    """

    def __init__(self, global_memory: GlobalMemory):
        Function.__init__(self, 'for_entry_list(slots, function, value_sep, value_sep_last, '
                                'slot_sep, slot_sep_last)')
        self.global_memory = global_memory

    def is_applicable(self, parameters: Memory) -> bool:
        """Return True when at least the six mandatory arguments were passed."""
        argument_count = len(parameters.variables)
        return argument_count >= 6

    def apply(self, parameters: Memory = None) -> str:
        """Render every value of every slot and join the texts on two levels.

        Positional variables: 0 = iterable of (slot, values) pairs, 1 = function name,
        2/3 = separators between values, 4/5 = separators between slots,
        6.. = extra arguments handed to the called function.

        Raises:
            BaseException: if the called function is not applicable to a value
        """
        variables = parameters.variables
        function = parameters.get_function(variables[1].value)
        extra_arguments = [extra.value for extra in variables[6:]]
        texts_per_slot: List[str] = []
        for entry in variables[0].value:
            value_texts: List[str] = []
            for value in entry[1]:
                memory = self._build_memory(entry[0], value, extra_arguments)
                if not function.is_applicable(memory):
                    raise BaseException(f'The function {function.function_name} could not be '
                                        f'called from the for_entry_list function')
                value_texts.append(function.apply(memory))
            slot_text = self._create_text_from_elements(value_texts, variables[2].value,
                                                        variables[3].value)
            texts_per_slot.append(slot_text)
        return self._create_text_from_elements(texts_per_slot, variables[4].value,
                                               variables[5].value)

    def _build_memory(self, slot: str, value: str, arguments: List[str]):
        """Create the memory for one call: ``slot``, ``value`` and ``arg<i>`` variables."""
        memory = Memory(self.global_memory)
        memory.add_variable(Variable('slot', slot))
        memory.add_variable(Variable('value', value))
        for position, argument in enumerate(arguments):
            memory.add_variable(Variable(f'arg{position}', argument))
        return memory

    def _create_text_from_elements(self, elements: List[str], separator: str, last_separator: str):
        """Join ``elements`` with ``separator``, using ``last_separator`` before the final one."""
        if not elements:
            return ''
        *leading, final = elements
        if not leading:
            return final
        return separator.join(leading) + last_separator + final
__init__(self, global_memory)
special
¶
apply(self, parameters=None)
¶
Source code in adviser/services/nlg/templates/builtinfunctions.py
def apply(self, parameters: Memory = None) -> str:
    """Render every value of every slot and join the texts on two levels.

    Positional variables: 0 = iterable of (slot, values) pairs, 1 = function name,
    2/3 = separators between values, 4/5 = separators between slots,
    6.. = extra arguments handed to the called function.

    Raises:
        BaseException: if the called function is not applicable to a value
    """
    variables = parameters.variables
    function = parameters.get_function(variables[1].value)
    extra_arguments = [extra.value for extra in variables[6:]]
    texts_per_slot: List[str] = []
    for entry in variables[0].value:
        value_texts: List[str] = []
        for value in entry[1]:
            memory = self._build_memory(entry[0], value, extra_arguments)
            if not function.is_applicable(memory):
                raise BaseException(f'The function {function.function_name} could not be '
                                    f'called from the for_entry_list function')
            value_texts.append(function.apply(memory))
        slot_text = self._create_text_from_elements(value_texts, variables[2].value,
                                                    variables[3].value)
        texts_per_slot.append(slot_text)
    return self._create_text_from_elements(texts_per_slot, variables[4].value,
                                           variables[5].value)
is_applicable(self, parameters)
¶
ForFunction (Function)
¶
Source code in adviser/services/nlg/templates/builtinfunctions.py
class ForFunction(Function):
    """Built-in template function ``for``.

    Applies another template function to every value of its first argument and
    joins the produced texts with the two given separators.
    """

    def __init__(self, global_memory):
        Function.__init__(self, 'for(values, function, separator_first, separator_last, *args)')
        self.global_memory = global_memory

    def is_applicable(self, parameters: Memory) -> bool:
        """Return True when at least the four mandatory arguments were passed."""
        argument_count = len(parameters.variables)
        return argument_count >= 4

    def apply(self, parameters: Memory = None) -> str:
        """Apply the named function to each value and join the results.

        Positional variables: 0 = iterable of values, 1 = function name,
        2 = separator, 3 = separator before the last element, 4.. = extra arguments
        handed to the called function as ``arg0``, ``arg1``, ...

        Raises:
            BaseException: if the called function is not applicable to a value
        """
        variables = parameters.variables
        function = parameters.get_function(variables[1].value)
        extra_arguments = [extra.value for extra in variables[4:]]
        texts: List[str] = []
        for item in variables[0].value:
            memory = self._build_memory(item, extra_arguments)
            if not function.is_applicable(memory):
                raise BaseException(f'The function {function.function_name} could not be called '
                                    f'from the for function')
            texts.append(function.apply(memory))
        separator = variables[2].value
        last_separator = variables[3].value
        return self._create_text_from_elements(texts, separator, last_separator)

    def _build_memory(self, value: str, arguments: List[str]):
        """Create the memory for one call: ``value`` and ``arg<i>`` variables."""
        memory = Memory(self.global_memory)
        memory.add_variable(Variable('value', value))
        for position, argument in enumerate(arguments):
            memory.add_variable(Variable(f'arg{position}', argument))
        return memory

    def _create_text_from_elements(self, elements: List[str], separator: str, last_separator: str):
        """Join ``elements`` with ``separator``, using ``last_separator`` before the final one."""
        if not elements:
            return ''
        *leading, final = elements
        if not leading:
            return final
        return separator.join(leading) + last_separator + final
__init__(self, global_memory)
special
¶
apply(self, parameters=None)
¶
Source code in adviser/services/nlg/templates/builtinfunctions.py
def apply(self, parameters: Memory = None) -> str:
    """Apply the named function to each value and join the results.

    Positional variables: 0 = iterable of values, 1 = function name,
    2 = separator, 3 = separator before the last element, 4.. = extra arguments.

    Raises:
        BaseException: if the called function is not applicable to a value
    """
    variables = parameters.variables
    function = parameters.get_function(variables[1].value)
    extra_arguments = [extra.value for extra in variables[4:]]
    texts: List[str] = []
    for item in variables[0].value:
        memory = self._build_memory(item, extra_arguments)
        if not function.is_applicable(memory):
            raise BaseException(f'The function {function.function_name} could not be called '
                                f'from the for function')
        texts.append(function.apply(memory))
    separator = variables[2].value
    last_separator = variables[3].value
    return self._create_text_from_elements(texts, separator, last_separator)
is_applicable(self, parameters)
¶
PythonFunction (Function)
¶
Source code in adviser/services/nlg/templates/builtinfunctions.py
class PythonFunction(Function):
    """Template function that forwards its arguments to a plain python callable."""

    def __init__(self, function_name: str, function_to_call: Callable,
                 obligatory_arguments: List[object] = None):
        """
        Arguments:
            function_name {str} -- name under which the function is registered
            function_to_call {Callable} -- python callable invoked by `apply`

        Keyword Arguments:
            obligatory_arguments {List[object]} -- values always passed first to the
                callable (default: a new empty list)
        """
        Function.__init__(self, f'{function_name}()')
        self.function = function_to_call
        # A literal `[]` default would be one list object shared by every instance
        # created without this argument; build a fresh list per instance instead.
        self.obligatory_arguments = [] if obligatory_arguments is None else obligatory_arguments

    def is_applicable(self, parameters: Memory) -> bool:
        """A python function accepts any parameters."""
        return True

    def apply(self, parameters: Memory = None) -> str:
        """Call the python function with the obligatory arguments followed by all variable values."""
        # copy so that the stored obligatory arguments are never mutated by a call
        arguments = self.obligatory_arguments.copy()
        arguments.extend(variable.value for variable in parameters.variables)
        return self.function(*arguments)
parsing
special
¶
automaton
¶
ModifiedPushdownAutomaton
¶
Source code in adviser/services/nlg/templates/parsing/automaton.py
class ModifiedPushdownAutomaton:
    """Character-driven automaton with a stack, configured through state descriptions.

    Every state description contributes explicit per-character transitions and
    one default transition that is used when no explicit transition matches.
    """

    def __init__(self, start_state: State, accept_states: List[State],
                 state_descriptions: List[StateDescription]):
        self.start_state = start_state
        self.accept_states = accept_states
        self.state_descriptions = state_descriptions
        # both lookup tables are derived from self.state_descriptions
        self.state_transition_mapping = self._create_state_transition_mapping()
        self.state_default_transition_mapping = self._create_state_default_transition_mapping()
        self.stack = AutomatonStack()

    def _create_state_transition_mapping(self) -> Dict[State, Dict[str, Transition]]:
        """Build the lookup table state -> input character -> transition."""
        mapping = {}
        for description in self.state_descriptions:
            per_character = mapping.setdefault(description.default_state, {})
            for transition in description.transitions:
                per_character[transition.input_configuration.character] = transition
        return mapping

    def _create_state_default_transition_mapping(self) -> Dict[State, DefaultTransition]:
        """Build the lookup table state -> default transition."""
        return {description.default_state: description.default_transition
                for description in self.state_descriptions}

    def parse(self, input_tape: str) -> List[object]:
        """Run the automaton over ``input_tape`` and return a copy of the data stack.

        Raises:
            ParsingError: if a transition fails (state, index and input are printed
                before re-raising) or the automaton does not end in an accept state
        """
        self.stack.clear()
        state = self.start_state
        for position, symbol in enumerate(input_tape):
            try:
                configuration = Configuration(state, symbol)
                transition = self._find_transition(configuration)
                state = self._apply_transition(transition, configuration)
            except ParsingError as error:
                print('State:', state.name)
                print('Index:', position)
                print('Original Input:', input_tape)
                raise error
        if state not in self.accept_states:
            print('State:', state.name)
            raise ParsingError('Parser was not in a final state after the input tape was read.')
        return list(self.stack.data_stack)

    def _apply_transition(self, transition: Transition,
                          input_configuration: Configuration) -> State:
        """Perform the transition's stack action, push the output character, return the new state."""
        transition.perform_stack_action(self.stack, input_configuration)
        output = transition.get_output_configuration(input_configuration)
        self.stack.add_char(output.character)
        return output.state

    def _find_transition(self, configuration: Configuration):
        """Return the explicit transition for the configuration, else the state's default."""
        per_character = self.state_transition_mapping.get(configuration.state)
        if per_character is not None and configuration.character in per_character:
            return per_character[configuration.character]
        return self._find_default_transition(configuration.state)

    def _find_default_transition(self, current_state: State):
        """Return the default transition of ``current_state`` or raise ParsingError."""
        defaults = self.state_default_transition_mapping
        if current_state in defaults:
            return defaults[current_state]
        raise ParsingError(f'No default transition found for state {current_state.name}.')
__init__(self, start_state, accept_states, state_descriptions)
special
¶Source code in adviser/services/nlg/templates/parsing/automaton.py
def __init__(self, start_state: State, accept_states: List[State],
             state_descriptions: List[StateDescription]):
    """Store the automaton definition and derive its transition lookup tables.

    Arguments:
        start_state {State} -- state in which parsing begins
        accept_states {List[State]} -- states in which parsing may end
        state_descriptions {List[StateDescription]} -- transitions of every state
    """
    self.state_descriptions = state_descriptions
    self.accept_states = accept_states
    self.start_state = start_state
    # the mappings are computed from self.state_descriptions, which is set above
    self.state_transition_mapping = self._create_state_transition_mapping()
    self.state_default_transition_mapping = self._create_state_default_transition_mapping()
    self.stack = AutomatonStack()
parse(self, input_tape)
¶Source code in adviser/services/nlg/templates/parsing/automaton.py
def parse(self, input_tape: str) -> List[object]:
    """Run the automaton over ``input_tape`` and return a copy of the data stack.

    Raises:
        ParsingError: if a transition fails (state, index and input are printed
            before re-raising) or the automaton does not end in an accept state
    """
    self.stack.clear()
    state = self.start_state
    for position, symbol in enumerate(input_tape):
        try:
            configuration = Configuration(state, symbol)
            transition = self._find_transition(configuration)
            state = self._apply_transition(transition, configuration)
        except ParsingError as error:
            print('State:', state.name)
            print('Index:', position)
            print('Original Input:', input_tape)
            raise error
    if state not in self.accept_states:
        print('State:', state.name)
        raise ParsingError('Parser was not in a final state after the input tape was read.')
    return list(self.stack.data_stack)
configuration
¶
Configuration
¶
Source code in adviser/services/nlg/templates/parsing/configuration.py
__init__(self, state, character)
special
¶
DefaultTransition (Transition)
¶
Source code in adviser/services/nlg/templates/parsing/configuration.py
__init__(self, state)
special
¶
SimpleForwardDefaultTransition (DefaultTransition)
¶
Source code in adviser/services/nlg/templates/parsing/configuration.py
class SimpleForwardDefaultTransition(DefaultTransition):
    """Default transition that stays in the current state and forwards the read character."""

    def __init__(self, state: State):
        DefaultTransition.__init__(self, state)

    def get_output_configuration(self, input_configuration: Configuration) -> Configuration:
        """Return a new configuration with the same state and character as the input."""
        state = input_configuration.state
        character = input_configuration.character
        return Configuration(state, character)

    def perform_stack_action(self, stack: AutomatonStack, configuration: Configuration):
        """Do nothing: this transition has no stack action."""
        return None
State
¶
StateDescription
¶
Source code in adviser/services/nlg/templates/parsing/configuration.py
__init__(self, default_state, default_transition, transitions)
special
¶
Transition
¶
Source code in adviser/services/nlg/templates/parsing/configuration.py
class Transition:
    """Base class of all transitions; subclasses implement both hook methods."""

    def __init__(self, input_configuration: Configuration):
        # configuration (state + character) for which this transition applies
        self.input_configuration = input_configuration

    def get_output_configuration(self, input_configuration: Configuration) -> Configuration:
        """Return the configuration reached by this transition (implemented by subclasses)."""
        raise NotImplementedError

    def perform_stack_action(self, stack: AutomatonStack, configuration: Configuration):
        """Manipulate the automaton stack for this transition (implemented by subclasses)."""
        raise NotImplementedError
TransitionWithAction (Transition)
¶
Source code in adviser/services/nlg/templates/parsing/configuration.py
class TransitionWithAction(Transition):
    """Transition with a fixed output configuration that also runs an action on the stack."""

    def __init__(self, input_configuration: Configuration, output_configuration: Configuration,
                 action: Callable[[AutomatonStack], None]):
        Transition.__init__(self, input_configuration)
        self.action = action
        self.output_configuration = output_configuration

    def get_output_configuration(self, input_configuration: Configuration) -> Configuration:
        """Return the fixed output configuration; the input configuration is not consulted."""
        return self.output_configuration

    def perform_stack_action(self, stack: AutomatonStack, configuration: Configuration):
        """Call the stored action with the stack; the configuration is not passed on."""
        action = self.action
        action(stack)
TransitionWithoutAction (Transition)
¶
Source code in adviser/services/nlg/templates/parsing/configuration.py
class TransitionWithoutAction(Transition):
    """Transition with a fixed output configuration and no stack action."""

    def __init__(self, input_configuration: Configuration, output_configuration: Configuration):
        Transition.__init__(self, input_configuration)
        self.output_configuration = output_configuration

    def get_output_configuration(self, input_configuration: Configuration) -> Configuration:
        """Return the fixed output configuration; the input configuration is not consulted."""
        return self.output_configuration

    def perform_stack_action(self, stack: AutomatonStack, configuration: Configuration):
        """Do nothing: this transition leaves the stack untouched."""
        return None
exceptions
¶
parsers
special
¶
messageparser
special
¶
messageparser
¶head_location
¶
MessageParser (ModifiedPushdownAutomaton)
¶Source code in adviser/services/nlg/templates/parsing/parsers/messageparser/messageparser.py
class MessageParser(ModifiedPushdownAutomaton):
    """Pushdown automaton preconfigured with the state descriptions of the message parser."""

    def __init__(self):
        # objects are created in the same order as the original inline call did
        start_state = StartState()
        accept_states = [AcceptState()]
        state_descriptions = [
            StartStateDescription(),
            AcceptStateDescription(),
            MessageStateDescription(),
            EscapeStateDescription(),
            CodeStateDescription(),
            AdviserStateDescription(),
            PythonStateDescription(),
            PythonClosingBraceStateDescription(),
            CodeStringStateDescription(),
        ]
        ModifiedPushdownAutomaton.__init__(self, start_state, accept_states, state_descriptions)
__init__(self)
special
¶Source code in adviser/services/nlg/templates/parsing/parsers/messageparser/messageparser.py
def __init__(self):
    """Configure the automaton with the message parser's start/accept states and descriptions."""
    # objects are created in the same order as the original inline call did
    start_state = StartState()
    accept_states = [AcceptState()]
    state_descriptions = [
        StartStateDescription(),
        AcceptStateDescription(),
        MessageStateDescription(),
        EscapeStateDescription(),
        CodeStateDescription(),
        AdviserStateDescription(),
        PythonStateDescription(),
        PythonClosingBraceStateDescription(),
        CodeStringStateDescription(),
    ]
    ModifiedPushdownAutomaton.__init__(self, start_state, accept_states, state_descriptions)
stack
¶
AutomatonStack
¶
Source code in adviser/services/nlg/templates/parsing/stack.py
class AutomatonStack:
    """Stack used by the automaton: nested character levels plus a separate data stack."""

    def __init__(self):
        self.data_stack = []  # the stack in which custom data structures can be stored
        self.levels = [[]]  # multiple automaton stacks are possible here

    def _top_level(self) -> list:
        """Return the innermost character level or raise if no level is left."""
        if self.levels:
            return self.levels[-1]
        raise ParsingError('No more levels left on the stack')

    def add_char(self, stack_char: str):
        """Append a character to the innermost level."""
        self._top_level().append(stack_char)

    def add_data(self, data: object):
        """Push an object onto the data stack."""
        self.data_stack.append(data)

    def pop_data(self) -> object:
        """Remove and return the top object of the data stack."""
        return self.data_stack.pop()

    def fetch_data(self) -> object:
        """Return the top object of the data stack without removing it."""
        return self.data_stack[-1]

    def add_level(self):
        """Open a new, empty character level."""
        self.levels.append([])

    def get_current_content(self) -> str:
        """Return the characters of the innermost level as one string."""
        return ''.join(self._top_level())

    def remove_level(self):
        """Discard the innermost character level."""
        if not self.levels:
            raise ParsingError('No more levels to remove from the stack')
        del self.levels[-1]

    def clear(self):
        """Reset both stacks to their initial state."""
        self.data_stack, self.levels = [], [[]]
preprocessing
¶
templatefile
¶
KEYWORDS
¶
TemplateFile
¶
Interprets a template file
Source code in adviser/services/nlg/templates/templatefile.py
class TemplateFile:
    """Interprets a template file

    Attributes:
        global_memory {GlobalMemory} -- memory that can be accessed at all times in the templates
    """

    def __init__(self, filename: str, domain: JSONLookupDomain):
        """Read the template file and register its templates and functions.

        Arguments:
            filename {str} -- path of the template file handed to the template file reader
            domain {JSONLookupDomain} -- domain the global memory is created for
        """
        self.global_memory = GlobalMemory(domain)
        self._add_built_in_functions()
        tfr = _TemplateFileReader(filename)
        self._templates = self._create_template_dict(tfr.get_templates())
        self._add_functions_to_global_memory(tfr.get_functions())

    def _add_built_in_functions(self):
        """Register the built-in `for`, `for_entry` and `for_entry_list` functions."""
        for function_class in (ForFunction, ForEntryFunction, ForEntryListFunction):
            self.global_memory.add_function(function_class(self.global_memory))

    def _create_template_dict(self, templates: List[Template]) -> Dict[str, List[Template]]:
        """Group the templates by intent, keeping their original order within each intent."""
        template_dict = {}
        for template in templates:
            template_dict.setdefault(template.intent, []).append(template)
        return template_dict

    def _add_functions_to_global_memory(self, functions: List[Function]):
        """Register every function defined in the template file in the global memory."""
        for function in functions:
            self.global_memory.add_function(function)

    def create_message(self, sys_act: SysAct) -> str:
        """Iterates through all possible templates and applies the first one to fit the system act

        Arguments:
            sys_act {SysAct} -- the system act to find a template for

        Raises:
            BaseException: when no template could be applied
            KeyError: when no template at all is registered for the intent of the system act

        Returns:
            str -- the message returned by the template
        """
        slots = self._create_memory_from_sys_act(sys_act)
        for template in self._templates[sys_act.type.value]:
            if template.is_applicable(slots):
                return template.apply(slots)
        raise BaseException('No template was found for the given system act.')

    def _create_memory_from_sys_act(self, sys_act: SysAct) -> Memory:
        """Create a memory holding one variable per slot of the system act."""
        slots = Memory(self.global_memory)
        for slot in sys_act.slot_values:
            slots.add_variable(Variable(slot, sys_act.slot_values[slot]))
        return slots

    def add_python_function(self, function_name: str, python_function: Callable[[object], str],
                            obligatory_arguments: List[object] = None):
        """Add a python function to the global memory of the template file interpreter

        Arguments:
            function_name {str} -- name under which the function can be accessed in template file
            python_function {Callable[[object], str]} -- python function which is called when being
                accessed in the template file

        Keyword Arguments:
            obligatory_arguments {List[object]} -- objects that are always passed as first
                arguments to the python function, e.g. "self" (default: a new empty list)
        """
        # A literal `[]` default would be a single list shared by every function
        # registered without obligatory arguments; create a fresh one per call.
        if obligatory_arguments is None:
            obligatory_arguments = []
        self.global_memory.add_function(PythonFunction(function_name, python_function,
                                                       obligatory_arguments))
__init__(self, filename, domain)
special
¶
Source code in adviser/services/nlg/templates/templatefile.py
add_python_function(self, function_name, python_function, obligatory_arguments=[])
¶
Add a python function to the global memory of the template file interpreter
Keyword arguments:
Name | Type | Description |
---|---|---|
obligatory_arguments | List[object] | objects that are always passed as first arguments to the python function, e.g. "self" (default: []) |
Source code in adviser/services/nlg/templates/templatefile.py
def add_python_function(self, function_name: str, python_function: Callable[[object], str],
                        obligatory_arguments: List[object] = None):
    """Add a python function to the global memory of the template file interpreter

    Arguments:
        function_name {str} -- name under which the function can be accessed in template file
        python_function {Callable[[object], str]} -- python function which is called when being
            accessed in the template file

    Keyword Arguments:
        obligatory_arguments {List[object]} -- objects that are always passed as first
            arguments to the python function, e.g. "self" (default: a new empty list)
    """
    # A literal `[]` default would be a single list shared by every function
    # registered without obligatory arguments; create a fresh one per call.
    if obligatory_arguments is None:
        obligatory_arguments = []
    self.global_memory.add_function(PythonFunction(function_name, python_function,
                                                   obligatory_arguments))
create_message(self, sys_act)
¶
Iterates through all possible templates and applies the first one to fit the system act
Exceptions:
Type | Description |
---|---|
BaseException |
when no template could be applied |
Returns:
Type | Description |
---|---|
str | the message returned by the template |
Source code in adviser/services/nlg/templates/templatefile.py
def create_message(self, sys_act: SysAct) -> str:
    """Render the system act with the first applicable template of its intent

    Arguments:
        sys_act {SysAct} -- the system act to find a template for

    Raises:
        BaseException: when none of the intent's templates is applicable

    Returns:
        str -- the message returned by the template
    """
    memory = self._create_memory_from_sys_act(sys_act)
    candidates = self._templates[sys_act.type.value]
    # templates are tried in order; the search stops at the first applicable one
    chosen = next((template for template in candidates if template.is_applicable(memory)), None)
    if chosen is None:
        raise BaseException('No template was found for the given system act.')
    return chosen.apply(memory)
nlu
special
¶
nlu
¶
HandcraftedNLU (Service)
¶
Class for Handcrafted Natural Language Understanding Module (HDC-NLU).
HDC-NLU is a rule-based approach to recognize the user acts as well as their respective slots and values from the user input (i.e. text) by means of regular expressions.
HDC-NLU is domain-independent. The regular expressions are read from JSON files.
There exists a JSON file that stores general rules (GeneralRules.json), i.e. rules that apply to any domain, e.g. rules to detect salutation (Hello, Hi).
There are two more files per domain that contain the domain-specific rules for request and inform user acts, e.g. ImsCoursesInformRules.json and ImsCoursesRequestRules.json.
The output during dialog interaction of this module is a semantic representation of the user input.
"I am looking for pizza" --> inform(slot=food,value=italian)
See the regex_generator under tools, if the existing regular expressions need to be changed or a new domain should be added.
Source code in adviser/services/nlu/nlu.py
class HandcraftedNLU(Service):
"""
Class for Handcrafted Natural Language Understanding Module (HDC-NLU).
HDC-NLU is a rule-based approach to recognize the user acts as well as
their respective slots and values from the user input (i.e. text)
by means of regular expressions.
HDC-NLU is domain-independet. The regular expressions of are read
from JSON files.
There exist a JSON file that stores general rules (GeneralRules.json),
i.e. rules that apply to any domain, e.g. rules to detect salutation (Hello, Hi).
There are two more files per domain that contain the domain-specific rules
for request and inform user acts, e.g. ImsCoursesInformRules.json and
ImsCoursesRequestRules.json.
The output during dialog interaction of this module is a semantic
representation of the user input.
"I am looking for pizza" --> inform(slot=food,value=italian)
See the regex_generator under tools, if the existing regular expressions
need to be changed or a new domain should be added.
"""
def __init__(self, domain: JSONLookupDomain, logger: DiasysLogger = DiasysLogger(),
             language: Language = None):
    """
    Loads
        - domain key
        - informable slots
        - requestable slots
        - domain-independent regular expressions
        - domain-specific regular expressions

    It sets the previous system act to None

    Args:
        domain {domain.jsonlookupdomain.JSONLookupDomain} -- Domain
        logger {DiasysLogger} -- logger used for dialog turn output
        language {Language} -- language of the regular expressions; falls back to
            Language.ENGLISH when not given
    """
    Service.__init__(self, domain=domain)
    self.logger = logger
    # The requested language is kept: previously it was unconditionally reset to
    # Language.ENGLISH right before _initialize(), so the argument had no effect.
    self.language = language if language else Language.ENGLISH

    # Getting domain information
    self.domain_name = domain.get_domain_name()
    self.domain_key = domain.get_primary_key()

    # Getting lists of informable and requestable slots
    self.USER_INFORMABLE = domain.get_informable_slots()
    self.USER_REQUESTABLE = domain.get_requestable_slots()

    # Getting the relative path where regexes are stored
    self.base_folder = os.path.join(get_root_dir(), 'resources', 'nlu_regexes')

    # Setting previous system act to None to signal the first turn
    self.sys_act_info = {
        'last_act': None, 'lastInformedPrimKeyVal': None, 'lastRequestSlot': None}

    # NOTE(review): _initialize is defined outside this excerpt; presumably it loads
    # the regex files for self.language - confirm non-English resources exist.
    self._initialize()
def dialog_start(self) -> dict:
    """
    Resets all per-dialog state; called when the dialog starts.

    The stored system act information is cleared (all entries None) and the
    collected user acts, informed/requested slots and the request-everything
    flag are reset.

    Returns:
        Nothing is returned explicitly (the method has no return statement)
    """
    self.user_acts = []
    self.slots_informed, self.slots_requested = set(), set()
    self.req_everything = False
    # all three entries start out as None
    self.sys_act_info = dict.fromkeys(
        ('last_act', 'lastInformedPrimKeyVal', 'lastRequestSlot'))
@PublishSubscribe(sub_topics=["user_utterance"], pub_topics=["user_acts"])
def extract_user_acts(self, user_utterance: str = None) -> dict(user_acts=List[UserAct]):
    """
    Responsible for detecting user acts with their respective slot-values from the user
    utterance through regular expressions.

    Args:
        user_utterance (str) - the text entered by the user; may be None, in which case
            no regular expression is matched

    Returns:
        dict of str: UserAct - a dictionary with the key "user_acts" and the value
            containing a list of user actions
    """
    # Setting request everything to False at every turn
    self.req_everything = False
    self.user_acts = []

    # slots_requested & slots_informed store slots requested and informed in this turn
    # they are used later for disambiguation
    self.slots_requested, self.slots_informed = set(), set()

    if user_utterance is not None:
        user_utterance = user_utterance.strip()
        self._match_general_act(user_utterance)
        self._match_domain_specific_act(user_utterance)

    self._solve_informable_values()

    # If nothing else has been matched, see if the user chose a domain; otherwise if it's
    # not the first turn, it's a bad act
    if not self.user_acts:
        text = user_utterance if user_utterance else ""
        # guard against None: `keyword in None` would raise a TypeError
        if user_utterance is not None and self.domain.get_keyword() in user_utterance:
            self.user_acts.append(UserAct(text=text, act_type=UserActionType.SelectDomain))
        elif self.sys_act_info['last_act'] is not None:
            # a system act preceded this turn and no regex matched
            self.user_acts.append(UserAct(text=text, act_type=UserActionType.Bad))

    self._assign_scores()
    self.logger.dialog_turn("User Actions: %s" % str(self.user_acts))
    return {'user_acts': self.user_acts}
@PublishSubscribe(sub_topics=["sys_state"])
def _update_sys_act_info(self, sys_state):
    """
    Copies the entries present in the system state into the stored system act info.

    Args:
        sys_state -- mapping that may contain "lastInformedPrimKeyVal",
            "lastRequestSlot" and "last_act"
    """
    # NOTE(review): the first two values are stored under 'last_offer' / 'last_request',
    # not under the 'lastInformedPrimKeyVal' / 'lastRequestSlot' keys the dict is
    # initialised with - confirm which names the readers expect.
    key_pairs = (
        ("lastInformedPrimKeyVal", 'last_offer'),
        ("lastRequestSlot", 'last_request'),
        ("last_act", 'last_act'),
    )
    for state_key, info_key in key_pairs:
        if state_key in sys_state:
            self.sys_act_info[info_key] = sys_state[state_key]
def _match_general_act(self, user_utterance: str):
"""
Finds general acts (e.g. Hello, Bye) in the user input
Every pattern in self.general_regex is searched case-insensitively in the
utterance. Matches are turned into user acts that are appended to
self.user_acts (directly or through self._add_inform), or they set
self.req_everything; nothing is returned.
Args:
user_utterance {str} -- text input from user
Returns:
"""
# Iteration over all general acts
for act in self.general_regex:
# Check if the regular expression and the user utterance match
if re.search(self.general_regex[act], user_utterance, re.I):
# Mapping the act to User Act
# 'dontcare' and 'req_everything' stay plain strings, everything else
# is converted to a UserActionType member
if act != 'dontcare' and act != 'req_everything':
user_act_type = UserActionType(act)
else:
user_act_type = act
# Check if the found user act is affirm or deny
if self.sys_act_info['last_act'] and (user_act_type == UserActionType.Affirm or
user_act_type == UserActionType.Deny):
# Conditions to check the history in order to assign affirm or deny
# slots mentioned in the previous system act
# Check if the preceding system act was confirm
if self.sys_act_info['last_act'].type == SysActionType.Confirm:
# Iterate over all slots in the system confirmation
# and make a list of Affirm/Deny(slot=value)
# where value is taken from the previous sys act
for slot in self.sys_act_info['last_act'].slot_values:
# New user act -- Affirm/Deny(slot=value)
user_act = UserAct(act_type=UserActionType(act),
text=user_utterance,
slot=slot,
value=self.sys_act_info['last_act'].slot_values[slot])
self.user_acts.append(user_act)
# Check if the preceding system act was request
# This covers the binary requests, e.g. 'Is the course related to Math?'
elif self.sys_act_info['last_act'].type == SysActionType.Request:
# Iterate over all slots in the system request
# and make a list of Inform(slot={True|False})
for slot in self.sys_act_info['last_act'].slot_values:
# Assign value for the slot mapping from Affirm or Request to Logical,
# True if user affirms, False if user denies
value = 'true' if user_act_type == UserActionType.Affirm else 'false'
# Adding user inform act
self._add_inform(user_utterance, slot, value)
# Check if Deny happens after System Request more, then trigger bye
elif self.sys_act_info['last_act'].type == SysActionType.RequestMore \
and user_act_type == UserActionType.Deny:
user_act = UserAct(text=user_utterance, act_type=UserActionType.Bye)
self.user_acts.append(user_act)
# Check if Request or Select is the previous system act
# NOTE(review): this branch reads self.sys_act_info['last_act'].type; the
# flattened excerpt does not show whether last_act is guaranteed to be set
# here - confirm it cannot be None when 'dontcare' matches.
elif user_act_type == 'dontcare':
if self.sys_act_info['last_act'].type == SysActionType.Request or \
self.sys_act_info['last_act'].type == SysActionType.Select:
# Iteration over all slots mentioned in the last system act
for slot in self.sys_act_info['last_act'].slot_values:
# Adding user inform act
self._add_inform(user_utterance, slot, value=user_act_type)
# Check if the user wants to get all information about a particular entity
elif user_act_type == 'req_everything':
self.req_everything = True
else:
# This section covers all general user acts that do not depend on
# the dialog history
# New user act -- UserAct()
user_act = UserAct(act_type=user_act_type, text=user_utterance)
self.user_acts.append(user_act)
def _match_domain_specific_act(self, user_utterance: str):
    """
    Matches in-domain user acts by running the request matcher first and the
    inform matcher second.

    Args:
        user_utterance {str} -- text input from user

    Returns:
    """
    # order matters: requests are collected before informs
    for matcher in (self._match_request, self._match_inform):
        matcher(user_utterance)
def _match_request(self, user_utterance: str):
    """
    Searches the request regex of every requestable slot in the user utterance
    (case-insensitive) and adds a request act for each slot that matches.

    Args:
        user_utterance {str} -- text input from user

    Returns:
    """
    for slot in self.USER_REQUESTABLE:
        pattern = self.request_regex[slot]
        match = re.search(pattern, user_utterance, re.IGNORECASE)
        if not self._check(match):
            continue
        self._add_request(user_utterance, slot)
def _add_request(self, user_utterance: str, slot: str):
    """
    Builds a Request(slot) user act, appends it to the user act list and
    records the slot in the set of requested slots.

    Args:
        user_utterance {str} -- text input from user
        slot {str} -- requested slot

    Returns:
    """
    request_act = UserAct(text=user_utterance, act_type=UserActionType.Request, slot=slot)
    self.user_acts.append(request_act)
    # remember that this slot was requested
    self.slots_requested.add(slot)
def _match_inform(self, user_utterance: str):
"""
Iterates over all user inform slot-value regexes and find matches with the user utterance
Args:
user_utterance {str} -- text input from user
Returns:
"""
# Iteration over all user informable slots and their slots
for slot in self.USER_INFORMABLE:
for value in self.inform_regex[slot]:
if self._check(re.search(self.inform_regex[slot][value], user_utterance, re.I)):
if slot == self.domain_key and self.req_everything:
# Adding all requestable slots because of the req_everything
for req_slot in self.USER_REQUESTABLE:
# skipping the domain key slot
if req_slot != self.domain_key:
# Adding user request act
self._add_request(user_utterance, req_slot)
# Adding user inform act
self._add_inform(user_utterance, slot, value)
def _add_inform(self, user_utterance: str, slot: str, value: str):
    """
    Builds an Inform user act for the slot/value pair and appends it to the user act list

    Args:
        user_utterance {str} -- text input from user
        slot {str} -- informed slot
        value {str} -- value for the informed slot
    """
    # New user act -- Inform(slot=value)
    self.user_acts.append(
        UserAct(text=user_utterance, act_type=UserActionType.Inform, slot=slot, value=value))
    # Keep track of the slot in the set of informed slots
    self.slots_informed.add(slot)
@staticmethod
def _exact_match(phrases: List[str], user_utterance: str) -> bool:
"""
Checks if the user utterance is exactly like one in the
Args:
phrases List[str] -- list of contextual don't cares
user_utterance {str} -- text input from user
Returns:
"""
# apostrophes are removed
if user_utterance.lstrip().lower().replace("'", "") in phrases:
return True
return False
def _match_affirm(self, user_utterance: str):
"""TO BE DEFINED AT A LATER POINT"""
pass
def _match_negative_inform(self, user_utterance: str):
"""TO BE DEFINED AT A LATER POINT"""
pass
@staticmethod
def _check(re_object) -> bool:
"""
Checks if the regular expression and the user utterance matched
Args:
re_object: output from re.search(...)
Returns:
True/False if match happened
"""
if re_object is None:
return False
for o in re_object.groups():
if o is not None:
return True
return False
def _assign_scores(self):
"""
Goes over the user act list, checks concurrencies and assign scores
Returns:
"""
for i in range(len(self.user_acts)):
# TODO: Create a clever and meaningful mechanism to assign scores
# Since the user acts are matched, they get 1.0 as score
self.user_acts[i].score = 1.0
def _disambiguate_co_occurrence(self, beliefstate: BeliefState):
    """
    Drops one of the two acts when a slot was both requested and informed in the same turn

    E.g. request(applied_nlp) & inform(applied_nlp=true), which is difficult to
    disambiguate using regexes. Request acts are dropped when there is no belief
    state or no primary key value has been informed by the system yet; otherwise
    the Inform acts are dropped.

    Args:
        beliefstate (BeliefState): current belief state, may be None
    """
    overlap = self.slots_requested.intersection(self.slots_informed)
    if not overlap:
        return

    nothing_offered = (
        beliefstate is None
        or self.sys_act_info['lastInformedPrimKeyVal'] in [None, '**NONE**', 'none'])
    act_to_del = UserActionType.Request if nothing_offered else UserActionType.Inform

    # keep every act that is not of the dropped type on an overlapping slot
    self.user_acts = [
        user_act for user_act in self.user_acts
        if not (user_act.type == act_to_del
                and any(user_act.slot == slot for slot in overlap))]
def _solve_informable_values(self):
# Verify if two or more informable slots with the same value were caught
# Cases:
# If a system request precedes and the slot is the on of the two informable, keep that one.
# If there is no preceding request, take
informed_values = {}
for i, user_act in enumerate(self.user_acts):
if user_act.type == UserActionType.Inform:
if user_act.value != "true" and user_act.value != "false":
if user_act.value not in informed_values:
informed_values[user_act.value] = [(i, user_act.slot)]
else:
informed_values[user_act.value].append((i, user_act.slot))
informed_values = {value: informed_values[value] for value in informed_values if
len(informed_values[value]) > 1}
if "6" in informed_values:
self.user_acts = []
def _initialize(self):
    """
    Loads the regex files that belong to the language stored in self.language

    The regular expressions are read from JSON files in self.base_folder as
    dictionaries {act:regex, ...} or {slot:{value:regex, ...}, ...} and stored in
    self.general_regex, self.request_regex and self.inform_regex.
    For any language other than English or German a message is printed and
    nothing is loaded.
    """
    if self.language == Language.ENGLISH:
        general_file = 'GeneralRules.json'
        language_tag = ''
    elif self.language == Language.GERMAN:
        # TODO: Change this once
        general_file = 'GeneralRulesGerman.json'
        language_tag = 'German'
    else:
        print('No language')
        return

    def _load_rules(file_name):
        # the context manager closes the file again, json.load(open(...)) leaked the handle
        with open(self.base_folder + '/' + file_name) as rules_file:
            return json.load(rules_file)

    self.general_regex = _load_rules(general_file)
    self.request_regex = _load_rules(self.domain_name + language_tag + 'RequestRules.json')
    self.inform_regex = _load_rules(self.domain_name + language_tag + 'InformRules.json')
__init__(self, domain, logger=<DiasysLogger adviser (NOTSET)>, language=None)
special
¶
Loads the domain key, the informable slots, the requestable slots, the domain-independent regular expressions and the domain-specific regular expressions.
It sets the previous system act to None
Source code in adviser/services/nlu/nlu.py
def __init__(self, domain: JSONLookupDomain, logger: DiasysLogger = DiasysLogger(),
             language: Language = None):
    """
    Loads
        - domain key
        - informable slots
        - requestable slots
        - domain-independent regular expressions
        - domain-specific regular expressions
    It sets the previous system act to None

    Args:
        domain {domain.jsonlookupdomain.JSONLookupDomain} -- Domain
        logger {DiasysLogger} -- logger stored on the instance
        language {Language} -- language of the regex files to load,
            English is used when no language is given
    """
    Service.__init__(self, domain=domain)
    self.logger = logger
    # The selected language is kept: it used to be overwritten with Language.ENGLISH
    # right before _initialize(), which silently ignored the argument
    self.language = language if language else Language.ENGLISH
    # Getting domain information
    self.domain_name = domain.get_domain_name()
    self.domain_key = domain.get_primary_key()
    # Getting lists of informable and requestable slots
    self.USER_INFORMABLE = domain.get_informable_slots()
    self.USER_REQUESTABLE = domain.get_requestable_slots()
    # Getting the relative path where regexes are stored
    self.base_folder = os.path.join(get_root_dir(), 'resources', 'nlu_regexes')
    # Setting previous system act to None to signal the first turn
    self.sys_act_info = {
        'last_act': None, 'lastInformedPrimKeyVal': None, 'lastRequestSlot': None}
    self._initialize()
dialog_start(self)
¶
Sets the previous system act as None. This function is called when the dialog starts
Returns:
Type | Description |
---|---|
dict | Empty dictionary |
Source code in adviser/services/nlu/nlu.py
def dialog_start(self) -> dict:
    """
    Sets the previous system act as None and clears all per-dialog state.
    This function is called when the dialog starts

    Returns:
        Empty dictionary
    """
    self.sys_act_info = {
        'last_act': None, 'lastInformedPrimKeyVal': None, 'lastRequestSlot': None}
    self.user_acts = []
    self.slots_informed = set()
    self.slots_requested = set()
    self.req_everything = False
    # honour the declared return type; the method used to fall through and return None
    return {}
extract_user_acts(self, *args, **kwargs)
¶
Source code in adviser/services/nlu/nlu.py
def delegate(self, *args, **kwargs):
# Calls `func` on this instance and, when the bound method is registered in
# self._publish_sockets, sends each returned value to its topic via `_send_msg`.
# The (key-normalised) result of `func` is returned in every case.
# NOTE(review): `func`, `pub_topics` and `_send_msg` are not defined in this block;
# presumably they are supplied by an enclosing scope -- confirm.
func_inst = getattr(self, func.__name__)
callargs = list(args)
if self in callargs: # remove self when in *args, because already known to function
callargs.remove(self)
result = func(self, *callargs, **kwargs)
if result:
# Result keys may look like "topic/domain": `domains` maps topic -> domain part
# ("" when the key has no "/"), then the keys of `result` are reduced to the bare topic.
# NOTE(review): split("/")[1] takes the segment after the FIRST "/", which differs
# from "only use last one" in the comment below when a key has several "/".
# fix! (user could have multiple "/" characters in topic - only use last one )
domains = {res.split("/")[0]: res.split("/")[1] if "/" in res else "" for res in result}
result = {key.split("/")[0]: result[key] for key in result}
if func_inst not in self._publish_sockets:
# not a publisher, just normal function
return result
socket = self._publish_sockets[func_inst]
domain = self._domain_name
if socket and result:
# publish messages
for topic in pub_topics:
# for topic in result: # NOTE publish any returned value in dict with it's key as topic
if topic in result:
# self._domain_name wins; otherwise the domain parsed from the result key is used.
# NOTE(review): `domain` is reassigned here, so a value picked for one topic is
# reused for the following topics -- confirm this is intended.
domain = domain if domain else domains[topic]
topic_domain_str = f"{topic}/{domain}" if domain else topic
# an entry in self._pub_topic_domains overrides the domain chosen above
if topic in self._pub_topic_domains:
topic_domain_str = f"{topic}/{self._pub_topic_domains[topic]}" if self._pub_topic_domains[topic] else topic
_send_msg(socket, topic_domain_str, result[topic])
if self.debug_logger:
self.debug_logger.info(
f"- (DS): sent message from {func} to topic {topic_domain_str}:\n {result[topic]}")
return result
get_root_dir()
¶
policy
special
¶
affective_policy
¶
EmotionPolicy (Service)
¶
Module for deciding what type of emotional response and what engagement level the system should give
Source code in adviser/services/policy/affective_policy.py
class EmotionPolicy(Service):
    """ Module that decides which emotion and which engagement level the system
        should show in its response
    """

    def __init__(self, domain: JSONLookupDomain = None, logger: DiasysLogger = DiasysLogger()):
        """
        Initializes the policy

        Arguments:
            domain (JSONLookupDomain): the domain that the affective policy should operate in
            logger (DiasysLogger): logger stored on the instance
        """
        self.first_turn = True
        Service.__init__(self, domain=domain)
        self.logger = logger

    def dialog_start(self):
        """Nothing has to be reset when a dialog starts"""

    @PublishSubscribe(sub_topics=["userstate"], pub_topics=["sys_emotion", "sys_engagement"])
    def choose_sys_emotion(self, userstate: UserState = None)\
            -> Dict[str, str]:
        """
        Mirrors the observed user emotion and user engagement as the system's choices
        for output emotion/engagement

        Args:
            userstate (UserState): a UserState object representing current system
                                   knowledge of the user's emotional state and engagement

        Returns:
            (dict): a dictionary with the keys "sys_emotion" and "sys_engagement" and the
                    corresponding values
        """
        emotion = userstate["emotion"]["category"].value
        engagement = userstate["engagement"].value
        return {"sys_emotion": emotion, "sys_engagement": engagement}
__init__(self, domain=None, logger=<DiasysLogger adviser (NOTSET)>)
special
¶
Initializes the policy
Parameters:
Name | Type | Description | Default |
---|---|---|---|
domain | JSONLookupDomain | the domain that the affective policy should operate in | None |
Source code in adviser/services/policy/affective_policy.py
choose_sys_emotion(self, *args, **kwargs)
¶
Source code in adviser/services/policy/affective_policy.py
def delegate(self, *args, **kwargs):
# Calls `func` on this instance and, when the bound method is registered in
# self._publish_sockets, sends each returned value to its topic via `_send_msg`.
# The (key-normalised) result of `func` is returned in every case.
# NOTE(review): `func`, `pub_topics` and `_send_msg` are not defined in this block;
# presumably they are supplied by an enclosing scope -- confirm.
func_inst = getattr(self, func.__name__)
callargs = list(args)
if self in callargs: # remove self when in *args, because already known to function
callargs.remove(self)
result = func(self, *callargs, **kwargs)
if result:
# Result keys may look like "topic/domain": `domains` maps topic -> domain part
# ("" when the key has no "/"), then the keys of `result` are reduced to the bare topic.
# NOTE(review): split("/")[1] takes the segment after the FIRST "/", which differs
# from "only use last one" in the comment below when a key has several "/".
# fix! (user could have multiple "/" characters in topic - only use last one )
domains = {res.split("/")[0]: res.split("/")[1] if "/" in res else "" for res in result}
result = {key.split("/")[0]: result[key] for key in result}
if func_inst not in self._publish_sockets:
# not a publisher, just normal function
return result
socket = self._publish_sockets[func_inst]
domain = self._domain_name
if socket and result:
# publish messages
for topic in pub_topics:
# for topic in result: # NOTE publish any returned value in dict with it's key as topic
if topic in result:
# self._domain_name wins; otherwise the domain parsed from the result key is used.
# NOTE(review): `domain` is reassigned here, so a value picked for one topic is
# reused for the following topics -- confirm this is intended.
domain = domain if domain else domains[topic]
topic_domain_str = f"{topic}/{domain}" if domain else topic
# an entry in self._pub_topic_domains overrides the domain chosen above
if topic in self._pub_topic_domains:
topic_domain_str = f"{topic}/{self._pub_topic_domains[topic]}" if self._pub_topic_domains[topic] else topic
_send_msg(socket, topic_domain_str, result[topic])
if self.debug_logger:
self.debug_logger.info(
f"- (DS): sent message from {func} to topic {topic_domain_str}:\n {result[topic]}")
return result
dialog_start(self)
¶
policy_api
¶
HandcraftedPolicy (Service)
¶
Handcrafted policy for API domains
Differs from the default HandcraftedPolicy class by taking into account mandatory slots, i.e. slots which have to be informed about before an API can even be called. The class is currently a copy of an older version of the HandcraftedPolicy class with the required changes for API usage. The classes will probably be merged in the future.
Source code in adviser/services/policy/policy_api.py
class HandcraftedPolicy(Service):
"""Handcrafted policy for API domains
Differs from the default HandcraftedPolicy class by taking into account mandatory slots,
i.e. slots which have to be informed about before an API can even be called.
The class is currently a copy of an older version of the HandcraftedPolicy class with the
required changes for API usage.
The classes will probably be merged in the future.
"""
def __init__(self, domain: LookupDomain, logger: DiasysLogger = DiasysLogger()):
    """
    Sets up the policy state for the given domain

    Arguments:
        domain {domain.lookupdomain.LookupDomain} -- Domain
        logger {DiasysLogger} -- logger stored on the instance
    """
    self.first_turn = True
    Service.__init__(self, domain=domain)
    self.logger = logger
    self.domain_key = domain.get_primary_key()
    self.last_action = None
    # suggestion bookkeeping: the list of current suggestions and the index
    # of the suggestion that is currently recommended by the system
    self.current_suggestions = []
    self.s_index = 0
@PublishSubscribe(sub_topics=["beliefstate"], pub_topics=["sys_act", "sys_state"])
def choose_sys_act(self, beliefstate: BeliefState = None, sys_act: SysAct = None)\
-> dict(sys_act=SysAct):
"""
Responsible for walking the policy through a single turn. Uses the user act types
stored in the belief state to determine what the next system action should be.
To implement an alternate policy, this method may need to be overwritten
Args:
beliefstate (BeliefState): a BeliefState object representing current system
knowledge; its "user_acts" entry is checked for the user act types handled here
sys_act (SysAct): stored as self.prev_sys_act; this should be None
Returns:
(dict): a dictionary with the key "sys_act" (the system's next action) and, except
for the welcome act returned on the first turn, the key "sys_state"
"""
# variables for general (non-domain specific) actions
# self.turn = dialog_graph.num_turns
self.prev_sys_act = sys_act
# filler act types are removed in place before the branches below look at "user_acts"
self._remove_gen_actions(beliefstate)
sys_state = {}
# do nothing on the first turn --LV
if self.first_turn and not beliefstate['user_acts']:
self.first_turn = False
sys_act = SysAct()
sys_act.type = SysActionType.Welcome
# NOTE: this early return carries no 'sys_state' entry
return {'sys_act': sys_act}
elif UserActionType.Bad in beliefstate["user_acts"]:
sys_act = SysAct()
sys_act.type = SysActionType.Bad
# if the action is 'bye' tell system to end dialog
elif UserActionType.Bye in beliefstate["user_acts"]:
sys_act = SysAct()
sys_act.type = SysActionType.Bye
# if user only says thanks, ask if they want anything else
elif UserActionType.Thanks in beliefstate["user_acts"]:
sys_act = SysAct()
sys_act.type = SysActionType.RequestMore
# If user only says hello, request a random slot to move dialog along
elif UserActionType.Hello in beliefstate["user_acts"] or UserActionType.SelectDomain in beliefstate["user_acts"]:
sys_act = SysAct()
sys_act.type = SysActionType.Request
# _get_open_slot may return None when every requestable slot is already filled
slot = self._get_open_slot(beliefstate)
sys_act.add_value(slot)
# prepare sys_state info
sys_state['lastRequestSlot'] = slot
# If we switch to the domain, start a new dialog
if UserActionType.SelectDomain in beliefstate["user_acts"]:
self.dialog_start()
# dialog_start() sets first_turn back to True, so it is cleared again afterwards
self.first_turn = False
# handle domain specific actions
else:
sys_act, sys_state = self._next_action(beliefstate)
self.logger.dialog_turn("System Action: " + str(sys_act))
# make sure the chosen act is always available under 'last_act'
if 'last_act' not in sys_state:
sys_state['last_act'] = sys_act
return {'sys_act': sys_act, 'sys_state': sys_state}
def dialog_start(self):
    """Resets the dialog-level state: first turn flag, last action and suggestions"""
    # no suggestions yet, so the list is empty and the recommendation index points at 0
    self.current_suggestions = []
    self.s_index = 0
    self.last_action = None
    self.first_turn = True
def _remove_gen_actions(self, beliefstate: BeliefState):
    """
    Removes filler act types (Thanks, Bad, Hello) from the user act types of the
    belief state as long as more than one act type is left.

    The collection stored under beliefstate["user_acts"] is modified in place;
    nothing is returned.

    Args:
        beliefstate (BeliefState): belief state whose "user_acts" entry is cleaned
    """
    act_types_lst = beliefstate["user_acts"]
    # checked in this order, one removal per round
    fillers = (UserActionType.Thanks, UserActionType.Bad, UserActionType.Hello)
    while len(act_types_lst) > 1:
        for filler in fillers:
            if filler in act_types_lst:
                act_types_lst.remove(filler)
                break
        else:
            # only non-filler act types are left
            break
def _query_db(self, beliefstate: BeliefState):
    """Issues the database lookup that fits the current belief state through self.domain

    If an entity name is known and the user has open requests, the requested
    slots of that entity are looked up; otherwise all entities that satisfy
    the user's constraints are searched.

    Returns:
        iterable: representing the results of the database lookup

    --LV
    """
    # an entity that was already suggested or that the user mentioned
    name = self._get_name(beliefstate)
    if not (name and beliefstate['requests']):
        # no entity to ask about: search by the constraints given so far
        constraints, _ = self._get_constraints(beliefstate)
        return self.domain.find_entities(constraints)
    requested_slots = beliefstate['requests']
    return self.domain.find_info_about_entity(name, requested_slots)
def _get_name(self, beliefstate: BeliefState):
    """Returns the identifier of the entity that is currently talked about, if any

    The primary key informed by the user wins (the candidate with the highest
    score); otherwise the current suggestion of the system is used.

    Args:
        beliefstate (BeliefState): current system beliefs

    Return:
        (str): the current entity name or None if there is none

    -LV
    """
    prim_key = self.domain.get_primary_key()
    if prim_key in beliefstate['informs']:
        # the user is trying to query by name
        possible_names = beliefstate['informs'][prim_key]
        ranked = sorted(possible_names.items(), key=lambda kv: kv[1], reverse=True)
        return ranked[0][0]

    if self.s_index >= len(self.current_suggestions):
        return None
    current_suggestion = self.current_suggestions[self.s_index]
    if not current_suggestion:
        return None
    return current_suggestion[self.domain_key]
def _get_constraints(self, beliefstate: BeliefState):
    """Splits the informed slots of the belief state into constraints and don't-cares

    Args:
        beliefstate (BeliefState): current system beliefs

    Return:
        (tuple): dict of informed slot names and their value (the last one iterated
                 when a slot has several) and list of slots the user doesn't care about

    --LV
    """
    informs = beliefstate["informs"]
    # slots the user explicitly does not care about are not constraints
    dontcare = [slot for slot in informs if "dontcare" in informs[slot]]
    slots = {}
    for slot in informs:
        if slot in dontcare:
            continue
        for value in informs[slot]:
            slots[slot] = value
    return slots, dontcare
def _mandatory_requests_fulfilled(self, belief_state: BeliefState):
    """Tells whether every mandatory slot of the domain has a value

    Arguments:
        belief_state (BeliefState): current system beliefs

    Returns:
        True if all mandatory slots are among the user's constraints, False otherwise
    """
    filled_slots, _ = self._get_constraints(belief_state)
    mandatory_slots = self.domain.get_mandatory_slots()
    return all(slot in filled_slots for slot in mandatory_slots)
def _get_open_mandatory_slot(self, belief_state: BeliefState):
    """Finds the first mandatory slot the user has not given a value for

    Args:
        belief_state (BeliefState): current system beliefs

    Returns:
        (str): the first open mandatory slot, None if all mandatory slots are filled
    """
    filled_slots, _ = self._get_constraints(belief_state)
    mandatory_slots = self.domain.get_mandatory_slots()
    return next((slot for slot in mandatory_slots if slot not in filled_slots), None)
def _get_open_slot(self, belief_state: BeliefState):
    """Finds the first system requestable slot the user has not yet specified a
    constraint for (used e.g. to answer a hello statement)

    Args:
        belief_state (BeliefState): current system beliefs

    Returns:
        (str): the first open system requestable slot, None if all of them are filled
    """
    filled_slots, _ = self._get_constraints(belief_state)
    requestable_slots = self.domain.get_system_requestable_slots()
    return next((slot for slot in requestable_slots if slot not in filled_slots), None)
def _next_action(self, beliefstate: BeliefState):
"""Determines the next system action based on the current belief state.
When implementing a new type of policy, this method MUST be rewritten
Args:
beliefstate (BeliefState): current belief state; its 'user_acts', 'informs' and
'requests' entries are read here
Return:
(tuple): the next system action (SysAct) and a dict with system state updates
--LV
"""
sys_state = {}
# Assuming this happens only because domain is not actually active --LV
"""if UserActionType.Bad in beliefstate['user_acts'] or beliefstate['requests'] \
and not self._get_name(beliefstate):
sys_act = SysAct()
sys_act.type = SysActionType.Bad
return sys_act, {'last_action': sys_act}"""
# a mandatory slot is still missing: request it before anything else
# NOTE(review): the three early returns below use the key 'last_action', while the
# final return of this method uses 'last_act' -- confirm which key consumers expect.
if not self._mandatory_requests_fulfilled(beliefstate):
sys_act = SysAct()
sys_act.type = SysActionType.Request
sys_act.slot_values = {self._get_open_mandatory_slot(beliefstate): None}
return sys_act, {'last_action': sys_act}
# alternatives were requested although the user has given no constraints
elif UserActionType.RequestAlternatives in beliefstate['user_acts'] \
and not self._get_constraints(beliefstate)[0]:
sys_act = SysAct()
sys_act.type = SysActionType.Bad
return sys_act, {'last_action': sys_act}
# the user informed the primary key and has no open requests: inform by that name
elif self.domain.get_primary_key() in beliefstate['informs'] \
and not beliefstate['requests']:
sys_act = SysAct()
sys_act.type = SysActionType.InformByName
sys_act.add_value(self.domain.get_primary_key(), self._get_name(beliefstate))
return sys_act, {'last_action': sys_act}
# Otherwise we need to query the db to determine next action
results = self._query_db(beliefstate)
sys_act = self._raw_action(results, beliefstate)
# requests are fairly easy, if it's a request, return it directly
if sys_act.type == SysActionType.Request:
if len(list(sys_act.slot_values.keys())) > 0:
# update the belief state to reflec the slot we just asked about
sys_state['lastRequestSlot'] = list(sys_act.slot_values.keys())[0]
# belief_state['system']['lastRequestSlot'] = list(sys_act.slot_values.keys())[0]
# otherwise we need to convert a raw inform into a one with proper slots and values
elif sys_act.type == SysActionType.InformByName:
self._convert_inform(results, sys_act, beliefstate)
# update belief state to reflect the offer we just made
values = sys_act.get_values(self.domain.get_primary_key())
if values:
# belief_state['system']['lastInformedPrimKeyVal'] = values[0]
sys_state['lastInformedPrimKeyVal'] = values[0]
else:
# no primary key value in the act: the literal 'none' is added instead
sys_act.add_value(self.domain.get_primary_key(), 'none')
sys_state['last_act'] = sys_act
return (sys_act, sys_state)
def _raw_action(self, q_res: iter, beliefstate: BeliefState) -> SysAct:
    """Based on the output of the db query, choose whether the next action
    should be a request or an inform

    Args:
        q_res (list): rows (list of dicts) returned by the issued query
        beliefstate (BeliefState): current system beliefs, passed on to
                                   _gen_next_request

    Returns:
        (SysAct): a Request for the next slot if there are several results and a
                  slot worth asking for, otherwise an empty InformByName act
                  (to be filled in later)

    --LV
    """
    sys_act = SysAct()
    # if there is more than one result
    if len(q_res) > 1:
        # Gather all the results for each column; the columns are taken from the
        # first row and the primary key column stays empty.
        # (The constraints were fetched here before but never used, so that call is gone.)
        temp = {key: [] for key in q_res[0].keys()}
        for result in q_res:
            for key in result.keys():
                if key != self.domain_key:
                    temp[key].append(result[key])
        # If any column has multiple values, ask for clarification
        next_req = self._gen_next_request(temp, beliefstate)
        if next_req:
            sys_act.type = SysActionType.Request
            sys_act.add_value(next_req)
            return sys_act

    # Otherwise action type will be inform, so return an empty inform (to be filled in later)
    sys_act.type = SysActionType.InformByName
    return sys_act
def _gen_next_request(self, temp: Dict[str, List[str]], belief_state: BeliefState):
    """
    Picks the slot to request next: non-binary slots are asked for first, then the
    binary slot chosen by _highest_info_gain

    NOTE: If the dataset is large, this is probably not a great idea to calculate each turn
          it's relatively simple, but could add up over time

    Args:
        temp (Dict[str, List[str]]: a dictionary with the keys and values for each result
                                    in the result set
        belief_state (BeliefState): current system beliefs

    Returns: (str) representing the slot to ask for next (or empty if none)
    """
    req_slots = self.domain.get_system_requestable_slots()
    # slots that are already specified (or not cared about) need no statistics
    constraints, dontcare = self._get_constraints(belief_state)
    open_slots = [s for s in req_slots if not (s in dontcare or s in constraints)]

    # binary slots are set aside so they can be asked about second
    bin_slots = []
    non_bin_slots = []
    for slot in open_slots:
        if len(self.domain.get_possible_values(slot)) == 2:
            bin_slots.append(slot)
        else:
            non_bin_slots.append(slot)

    # a non-binary slot with more than one distinct value in the results is asked for
    for slot in non_bin_slots:
        if len(set(temp[slot])) > 1:
            return slot

    # otherwise see whether a binary slot still splits the results
    return self._highest_info_gain(bin_slots, temp)
def _highest_info_gain(self, bin_slots: List[str], temp: Dict[str, List[str]]):
""" Since we don't have lables, we can't properlly calculate entropy, so instead we'll go
for trying to ask after a feature that splits the results in half as evenly as possible
(that way we gain most info regardless of which way the user chooses)
Args:
bin_slots: a list of strings representing system requestable binary slots which
have not yet been specified
temp (Dict[str, List[str]]: a dictionary with the keys and values for each result
in the result set
Returns: (str) representing the slot to ask for next (or empty if none)
"""
diffs = {}
for slot in bin_slots:
val1, val2 = self.domain.get_possible_values(slot)
values_dic = defaultdict(int)
for val in temp[slot]:
values_dic[val] += 1
if val1 in values_dic and val2 in values_dic:
diffs[slot] = abs(values_dic[val1] - values_dic[val2])
# If all slots have the same value, we don't need to request anything, return none
if not diffs:
return ""
sorted_diffs = sorted(diffs.items(), key=lambda kv: kv[1])
return sorted_diffs[0][0]
def _convert_inform(self, q_results: iter,
                    sys_act: SysAct, beliefstate: BeliefState):
    """Fills in the slots and values for a raw inform so it can be returned as the
    next system action.

    The act is filled by primary key if the user informed the primary key, by
    alternatives if the user requested alternatives, and by constraints otherwise.

    Args:
        q_results (list): Results of the database query
        sys_act (SysAct): the act to be modified
        beliefstate (BeliefState): current system beliefs

    --LV
    """
    if self.domain.get_primary_key() in beliefstate['informs']:
        self._convert_inform_by_primkey(q_results, sys_act, beliefstate)
        return
    if UserActionType.RequestAlternatives in beliefstate['user_acts']:
        self._convert_inform_by_alternatives(sys_act, q_results, beliefstate)
        return
    self._convert_inform_by_constraints(q_results, sys_act, beliefstate)
def _convert_inform_by_primkey(self, q_results: iter,
                               sys_act: SysAct, belief_state: BeliefState):
    """
    Fills a raw sys_act with the slots and values of the first query result; used
    when the system answers a request for information about an entity

    Args:
        q_results (iterable): list of query results from the database
        sys_act (SysAct): current raw sys_act to be filled in
        belief_state (BeliefState): current system beliefs
    """
    sys_act.type = SysActionType.InformByName
    if not q_results:
        sys_act.add_value(self.domain_key, 'none')
        return

    result = q_results[0]  # only the first result is used
    keys = list(result.keys())
    # slots + values of the result, with a placeholder for empty values
    for k in keys:
        sys_act.add_value(k, result[k] or 'not available')
    # the name might be missing from the result columns, so it is added explicitly
    if self.domain_key not in keys:
        sys_act.add_value(self.domain_key, self._get_name(belief_state))
    # default inform slots that are not part of the act yet
    for slot in self.domain.get_default_inform_slots():
        if slot not in sys_act.slot_values:
            sys_act.add_value(slot, result[slot])
def _convert_inform_by_alternatives(
self, sys_act: SysAct, q_res: iter, belief_state: BeliefState):
"""
Helper Function, scrolls through the list of alternative entities which match the
user's specified constraints and uses the next item in the list to fill in the raw
inform act.
When the end of the list is reached, the primary key value 'none' is added to the act
and the index is set to the last item in the list
Args:
sys_act (SysAct): the raw inform to be filled in
q_res (iter): results of the database query; only used to fill the suggestion list
when it is still empty
belief_state (BeliefState): current system belief state ()
"""
# (re)build the suggestion list only when there are results and no suggestions stored yet;
# the index starts at -1 because it is incremented before it is used
if q_res and not self.current_suggestions:
self.current_suggestions = []
self.s_index = -1
for result in q_res:
self.current_suggestions.append(result)
# move on to the next suggestion
self.s_index += 1
# here we should scroll through possible offers presenting one each turn the user asks
# for alternatives
if self.s_index <= len(self.current_suggestions) - 1:
# the first time we inform, we should inform by name, so we use the right template
if self.s_index == 0:
sys_act.type = SysActionType.InformByName
else:
sys_act.type = SysActionType.InformByAlternatives
result = self.current_suggestions[self.s_index]
# Inform by alternatives according to our current templates is
# just a normal inform apparently --LV
sys_act.add_value(self.domain_key, result[self.domain_key])
else:
# the index ran past the end of the list: nothing new is left to offer
sys_act.type = SysActionType.InformByAlternatives
# default to last suggestion in the list
self.s_index = len(self.current_suggestions) - 1
sys_act.add_value(self.domain.get_primary_key(), 'none')
# in addition to the name, add the constraints the user has specified, so they know the
# offer is relevant to them
constraints, dontcare = self._get_constraints(belief_state)
for c in constraints:
sys_act.add_value(c, constraints[c])
def _convert_inform_by_constraints(self, q_results: iter,
                                   sys_act: SysAct, belief_state: BeliefState):
    """
    Helper function for filling in slots and values of a raw inform act when the system is
    ready to make the user an offer

    The query results replace the current suggestions and the first one is offered.
    If there are no results, the primary key value 'none' is added instead and the
    stored suggestions are left untouched.

    Args:
        q_results (iter): the results from the database query
        sys_act (SysAct): the raw inform act to be filled in
        belief_state (BeliefState): the current system beliefs
    """
    # Materialise the results once: testing list(q_results) and then iterating
    # q_results again yields nothing the second time for one-shot iterables.
    results = list(q_results)
    if results:
        self.current_suggestions = results
        self.s_index = 0
        result = self.current_suggestions[0]
        sys_act.add_value(self.domain_key, result[self.domain_key])
        # Add default Inform slots
        for slot in self.domain.get_default_inform_slots():
            if slot not in sys_act.slot_values:
                sys_act.add_value(slot, result[slot])
    else:
        sys_act.add_value(self.domain_key, 'none')

    sys_act.type = SysActionType.InformByName
    constraints, dontcare = self._get_constraints(belief_state)
    for c in constraints:
        # Using constraints here rather than results to deal with empty
        # results sets (eg. user requests something impossible) --LV
        sys_act.add_value(c, constraints[c])

    # requested slots that are not in the act yet are answered from the first suggestion
    if self.current_suggestions:
        for slot in belief_state['requests']:
            if slot not in sys_act.slot_values:
                sys_act.add_value(slot, self.current_suggestions[0][slot])
__init__(self, domain, logger=<DiasysLogger adviser (NOTSET)>)
special
¶
Initializes the policy
Source code in adviser/services/policy/policy_api.py
def __init__(self, domain: LookupDomain, logger: DiasysLogger = DiasysLogger()):
    """
    Sets up the policy state for the given domain

    Arguments:
        domain {domain.lookupdomain.LookupDomain} -- Domain
        logger {DiasysLogger} -- logger stored on the instance
    """
    self.first_turn = True
    Service.__init__(self, domain=domain)
    self.logger = logger
    self.domain_key = domain.get_primary_key()
    self.last_action = None
    # suggestion bookkeeping: the list of current suggestions and the index
    # of the suggestion that is currently recommended by the system
    self.current_suggestions = []
    self.s_index = 0
choose_sys_act(self, *args, **kwargs)
¶
Source code in adviser/services/policy/policy_api.py
def delegate(self, *args, **kwargs):
    """
    Call the wrapped function and publish the values it returns.

    The wrapped function's result is expected to be a dict keyed by topic, optionally
    written as "topic/domain". Keys are reduced to the bare topic; if this method is not
    registered in `self._publish_sockets` the result is simply returned. Otherwise each
    topic from `pub_topics` that is present in the result is sent on the socket.

    NOTE(review): `func`, `pub_topics` and `_send_msg` are not defined in this function;
    presumably they come from the enclosing decorator scope -- confirm.

    Returns:
        the (key-normalized) result of the wrapped function
    """
    bound_method = getattr(self, func.__name__)

    # self is passed explicitly below, so drop it from the positional arguments
    call_args = list(args)
    if self in call_args:
        call_args.remove(self)

    result = func(self, *call_args, **kwargs)
    if result:
        # fix! (user could have multiple "/" characters in topic - only use last one )
        domains = {}
        for key in result:
            parts = key.split("/")
            domains[parts[0]] = parts[1] if "/" in key else ""
        result = {key.split("/")[0]: result[key] for key in result}

    if bound_method not in self._publish_sockets:
        # not a publisher, just normal function
        return result

    socket = self._publish_sockets[bound_method]
    domain = self._domain_name
    if not (socket and result):
        return result

    # publish messages
    for topic in pub_topics:
        if topic not in result:
            continue
        domain = domain or domains[topic]
        topic_domain_str = f"{topic}/{domain}" if domain else topic
        if topic in self._pub_topic_domains:
            # an explicitly configured per-topic domain takes precedence
            configured = self._pub_topic_domains[topic]
            topic_domain_str = f"{topic}/{configured}" if configured else topic
        _send_msg(socket, topic_domain_str, result[topic])
        if self.debug_logger:
            self.debug_logger.info(
                f"- (DS): sent message from {func} to topic {topic_domain_str}:\n {result[topic]}")
    return result
dialog_start(self)
¶
This function is called before the first message to a new dialog is published. You should overwrite this function to set/reset dialog-level variables.
policy_handcrafted
¶
HandcraftedPolicy (Service)
¶
Base class for handcrafted policies.
Provides a simple rule-based policy. Can be used for any domain where a user is trying to find an entity (eg. a course from a module handbook) from a database by providing constraints (eg. semester the course is offered) or where a user is trying to find out additional information about a named entity.
Output is a system action such as:
* inform
: provides information on an entity
* request
: request more information from the user
* bye
: issue parting message and end dialog
In order to create your own policy, you can inherit from this class.
Make sure to overwrite the choose_sys_act method with whatever additional
rules/functionality are required.
Source code in adviser/services/policy/policy_handcrafted.py
class HandcraftedPolicy(Service):
""" Base class for handcrafted policies.
Provides a simple rule-based policy. Can be used for any domain where a user is
trying to find an entity (eg. a course from a module handbook) from a database
by providing constraints (eg. semester the course is offered) or where a user is
trying to find out additional information about a named entity.
Output is a system action such as:
* `inform`: provides information on an entity
* `request`: request more information from the user
* `bye`: issue parting message and end dialog
In order to create your own policy, you can inherit from this class.
Make sure to overwrite the `choose_sys_act`-method with whatever additionally
rules/functionality required.
"""
def __init__(self, domain: JSONLookupDomain, logger: DiasysLogger = DiasysLogger(),
             max_turns: int = 25):
    """
    Set up the policy for one domain.

    Arguments:
        domain {domain.jsonlookupdomain.JSONLookupDomain} -- Domain
        logger {DiasysLogger} -- stored on the instance as `self.logger`
        max_turns {int} -- stored as `self.max_turns`
    """
    self.first_turn = True
    Service.__init__(self, domain=domain)
    # suggestions offered so far and the index of the current recommendation among them
    self.current_suggestions, self.s_index = [], 0
    self.domain_key = domain.get_primary_key()
    self.logger = logger
    self.max_turns = max_turns
def dialog_start(self):
    """
    Reset all dialog-level state so the policy can start a fresh dialog.
    """
    # turn counter and first-turn flag
    self.turns, self.first_turn = 0, True
    # suggestions offered so far and the index of the current recommendation among them
    self.current_suggestions, self.s_index = [], 0
@PublishSubscribe(sub_topics=["beliefstate"], pub_topics=["sys_act", "sys_state"])
def choose_sys_act(self, beliefstate: BeliefState) \
        -> dict(sys_act=SysAct):
    """
    Responsible for walking the policy through a single turn. Uses the current user
    action and system belief state to determine what the next system action should be.

    To implement an alternate policy, this method may need to be overwritten

    Args:
        beliefstate (BeliefState): a BeliefState object representing current system
                                   knowledge

    Returns:
        (dict): a dictionary with the key "sys_act" holding the system's next action and
                the key "sys_state" holding a dict which always contains "last_act"
    """
    self.turns += 1
    sys_state = {}
    # on the first turn without any user acts, just welcome the user --LV
    if self.first_turn and not beliefstate['user_acts']:
        self.first_turn = False
        sys_act = SysAct()
        sys_act.type = SysActionType.Welcome
        sys_state["last_act"] = sys_act
        return {'sys_act': sys_act, "sys_state": sys_state}
    # Handles case where it was the first turn, but there are user acts
    elif self.first_turn:
        self.first_turn = False

    # end the dialog once the turn budget is used up
    if self.turns >= self.max_turns:
        sys_act = SysAct()
        sys_act.type = SysActionType.Bye
        sys_state["last_act"] = sys_act
        return {'sys_act': sys_act, "sys_state": sys_state}

    # removes hello and thanks if there are also domain specific actions
    self._remove_gen_actions(beliefstate)

    if UserActionType.Bad in beliefstate["user_acts"]:
        sys_act = SysAct()
        sys_act.type = SysActionType.Bad
    # if the action is 'bye' tell system to end dialog
    elif UserActionType.Bye in beliefstate["user_acts"]:
        sys_act = SysAct()
        sys_act.type = SysActionType.Bye
    # if user only says thanks, ask if they want anything else
    elif UserActionType.Thanks in beliefstate["user_acts"]:
        sys_act = SysAct()
        sys_act.type = SysActionType.RequestMore
    # If user only says hello, request an open slot to move dialog along
    elif UserActionType.Hello in beliefstate["user_acts"] \
            or UserActionType.SelectDomain in beliefstate["user_acts"]:
        # look the open slot up once and reuse it for both the check and the request
        open_slot = self._get_open_slot(beliefstate)
        sys_act = SysAct()
        if open_slot:
            sys_act.type = SysActionType.Request
            sys_act.add_value(open_slot)
        # If there are no more open slots, ask the user if you can help with anything else
        # since this can only happen in the case an offer has already been made --LV
        else:
            sys_act.type = SysActionType.RequestMore

        # If we switch to the domain, start a new dialog
        if UserActionType.SelectDomain in beliefstate["user_acts"]:
            self.dialog_start()
        # dialog_start() re-arms the first-turn flag, so clear it again
        self.first_turn = False
    # handle domain specific actions
    else:
        sys_act, sys_state = self._next_action(beliefstate)

    if self.logger:
        self.logger.dialog_turn("System Action: " + str(sys_act))
    if "last_act" not in sys_state:
        sys_state["last_act"] = sys_act
    return {'sys_act': sys_act, "sys_state": sys_state}
def _remove_gen_actions(self, beliefstate: BeliefState):
    """
    Strip filler user actions (Thanks, Bad, Hello) from the belief state's action
    collection, in place, as long as more than one action remains.

    Args:
        beliefstate (BeliefState): BeliefState object - includes list of all
                                   current UserActionTypes
    """
    user_acts = beliefstate["user_acts"]
    # Fillers are dropped in this fixed priority order; removal stops as soon as only a
    # single action is left, so a lone filler is never removed.
    for filler in (UserActionType.Thanks, UserActionType.Bad, UserActionType.Hello):
        while len(user_acts) > 1 and filler in user_acts:
            user_acts.remove(filler)
def _query_db(self, beliefstate: BeliefState):
"""Based on the constraints specified, uses the domain to generate the appropriate type
of query for the database
Args:
beliefstate (BeliefState): BeliefState object; contains all given user constraints to date
Returns:
iterable: representing the results of the database lookup
--LV
"""
# determine if an entity has already been suggested or was mentioned by the user
name = self._get_name(beliefstate)
# if yes and the user is asking for info about a specific entity, generate a query to get
# that info for the slots they have specified
if name and beliefstate['requests']:
requested_slots = beliefstate['requests']
return self.domain.find_info_about_entity(name, requested_slots)
# otherwise, issue a query to find all entities which satisfy the constraints the user
# has given so far
else:
constraints, _ = self._get_constraints(beliefstate)
return self.domain.find_entities(constraints)
def _get_name(self, beliefstate: BeliefState):
"""Finds if an entity has been suggested by the system (in the form of an offer candidate)
or by the user (in the form of an InformByName act). If so returns the identifier for
it, otherwise returns None
Args:
beliefstate (BeliefState): BeliefState object, contains all known user informs
Return:
(str): Returns a string representing the current entity name
-LV
"""
name = None
prim_key = self.domain.get_primary_key()
if prim_key in beliefstate['informs']:
possible_names = beliefstate['informs'][prim_key]
name = sorted(possible_names.items(), key=lambda kv: kv[1], reverse=True)[0][0]
# if the user is tyring to query by name
else:
if self.s_index < len(self.current_suggestions):
current_suggestion = self.current_suggestions[self.s_index]
if current_suggestion:
name = current_suggestion[self.domain_key]
return name
def _get_constraints(self, beliefstate: BeliefState):
"""Reads the belief state and extracts any user specified constraints and any constraints
the user indicated they don't care about, so the system knows not to ask about them
Args:
beliefstate (BeliefState): BeliefState object; contains all user constraints to date
Return:
(tuple): dict of user requested slot names and their values and list of slots the user
doesn't care about
--LV
"""
slots = {}
# parts of the belief state which don't contain constraints
dontcare = [slot for slot in beliefstate['informs'] if "dontcare" in beliefstate["informs"][slot]]
informs = beliefstate["informs"]
slots = {}
# TODO: consider threshold of belief for adding a value? --LV
for slot in informs:
if slot not in dontcare:
for value in informs[slot]:
slots[slot] = value
return slots, dontcare
def _get_open_slot(self, beliefstate: BeliefState):
"""For a hello statement we need to be able to figure out what slots the user has not yet
specified constraint for, this method returns one of those at random
Args:
beliefstate (BeliefState): BeliefState object; contains all user constraints to date
Returns:
(str): a string representing a category the system might want more info on. If all
system requestables have been filled, return none
"""
filled_slots, _ = self._get_constraints(beliefstate)
requestable_slots = self.domain.get_system_requestable_slots()
for slot in requestable_slots:
if slot not in filled_slots:
return slot
return None
def _next_action(self, beliefstate: BeliefState):
    """Determines the next system action based on the current belief state and
    previous action.

    When implementing a new type of policy, this method MUST be rewritten

    Args:
        beliefstate (BeliefState): BeliefState object; contains all user constraints to date

    Return:
        (tuple): the next system action (SysAct) and a system-state dict that always
                 holds the action under 'last_act'

    --LV
    """
    def _bad_act():
        # shared reply for every situation this policy cannot handle
        bad = SysAct()
        bad.type = SysActionType.Bad
        return bad, {'last_act': bad}

    # Assuming this happens only because domain is not actually active --LV
    if UserActionType.Bad in beliefstate['user_acts'] \
            or (beliefstate['requests'] and not self._get_name(beliefstate)):
        return _bad_act()

    # alternatives were requested although nothing has been constrained yet
    if UserActionType.RequestAlternatives in beliefstate['user_acts'] \
            and not self._get_constraints(beliefstate)[0]:
        return _bad_act()

    # the user named an entity without requesting anything about it
    if self.domain.get_primary_key() in beliefstate['informs'] \
            and not beliefstate['requests']:
        named = SysAct()
        named.type = SysActionType.InformByName
        named.add_value(self.domain.get_primary_key(), self._get_name(beliefstate))
        return named, {'last_act': named}

    # Otherwise we need to query the db to determine next action
    sys_state = {}
    results = self._query_db(beliefstate)
    sys_act = self._raw_action(results, beliefstate)

    if sys_act.type == SysActionType.Request:
        # requests are returned as they are; just remember which slot was asked for
        requested = list(sys_act.slot_values.keys())
        if requested:
            sys_state['lastRequestSlot'] = requested[0]
    elif sys_act.type == SysActionType.InformByName:
        # a raw inform still has to be filled with proper slots and values
        self._convert_inform(results, sys_act, beliefstate)
        # remember the entity which was just offered
        offered = sys_act.get_values(self.domain.get_primary_key())
        if offered:
            sys_state['lastInformedPrimKeyVal'] = offered[0]
    else:
        sys_act.add_value(self.domain.get_primary_key(), 'none')

    sys_state['last_act'] = sys_act
    return (sys_act, sys_state)
def _raw_action(self, q_res: iter, beliefstate: BeliefState) -> SysAct:
    """Decide, from the query results, whether the next action is a request or an inform.

    Args:
        q_res (list): rows (list of dicts) returned by the issued query
        beliefstate (BeliefState): contains all UserActionTypes for the current turn

    Returns:
        (SysAct): a Request act naming the slot to ask for, or an empty InformByName act
                  (to be filled in later)

    --LV
    """
    sys_act = SysAct()

    # several candidates and nothing specific requested: try to narrow the results down
    if len(q_res) > 1 and not beliefstate['requests']:
        # result unused; call kept so this method behaves exactly as before
        _constraints, _dontcare = self._get_constraints(beliefstate)

        # collect, per column, the values of all results (primary key values excluded)
        column_values = {column: [] for column in q_res[0].keys()}
        for row in q_res:
            for column in row.keys():
                if column == self.domain_key:
                    continue
                column_values[column].append(row[column])

        slot_to_request = self._gen_next_request(column_values, beliefstate)
        if slot_to_request:
            sys_act.type = SysActionType.Request
            sys_act.add_value(slot_to_request)
            return sys_act

    # Otherwise action type will be inform, so return an empty inform
    sys_act.type = SysActionType.InformByName
    return sys_act
def _gen_next_request(self, temp: Dict[str, List[str]], belief_state: BeliefState):
"""
Calculates which slot to request next based asking for non-binary slotes first and then
based on which binary slots provide the biggest reduction in the size of db results
NOTE: If the dataset is large, this is probably not a great idea to calculate each turn
it's relatively simple, but could add up over time
Args:
temp (Dict[str, List[str]]: a dictionary with the keys and values for each result
in the result set
Returns: (str) representing the slot to ask for next (or empty if none)
"""
req_slots = self.domain.get_system_requestable_slots()
# don't other to cacluate statistics for things which have been specified
constraints, dontcare = self._get_constraints(belief_state)
# split out binary slots so we can ask about them second
req_slots = [s for s in req_slots if s not in dontcare and s not in constraints]
bin_slots = [slot for slot in req_slots if len(self.domain.get_possible_values(slot)) == 2]
non_bin_slots = [slot for slot in req_slots if slot not in bin_slots]
# check if there are any differences in values for non-binary slots,
# if a slot has multiple values, ask about that slot
for slot in non_bin_slots:
if len(set(temp[slot])) > 1:
return slot
# Otherwise look to see if there are differnces in binary slots
return self._highest_info_gain(bin_slots, temp)
def _highest_info_gain(self, bin_slots: List[str], temp: Dict[str, List[str]]):
""" Since we don't have lables, we can't properlly calculate entropy, so instead we'll go
for trying to ask after a feature that splits the results in half as evenly as possible
(that way we gain most info regardless of which way the user chooses)
Args:
bin_slots: a list of strings representing system requestable binary slots which
have not yet been specified
temp (Dict[str, List[str]]: a dictionary with the keys and values for each result
in the result set
Returns: (str) representing the slot to ask for next (or empty if none)
"""
diffs = {}
for slot in bin_slots:
val1, val2 = self.domain.get_possible_values(slot)
values_dic = defaultdict(int)
for val in temp[slot]:
values_dic[val] += 1
if val1 in values_dic and val2 in values_dic:
diffs[slot] = abs(values_dic[val1] - values_dic[val2])
# If all slots have the same value, we don't need to request anything, return none
if not diffs:
return ""
sorted_diffs = sorted(diffs.items(), key=lambda kv: kv[1])
return sorted_diffs[0][0]
def _convert_inform(self, q_results: iter,
sys_act: SysAct, beliefstate: BeliefState):
"""Fills in the slots and values for a raw inform so it can be returned as the
next system action.
Args:
q_results (list): Results of SQL database query
sys_act (SysAct): the act to be modified
beliefstate(BeliefState): BeliefState object; contains all user constraints to date and
the UserActionTypes for the current turn
--LV
"""
if beliefstate["requests"] or self.domain.get_primary_key() in beliefstate['informs']:
self._convert_inform_by_primkey(q_results, sys_act, beliefstate)
elif UserActionType.RequestAlternatives in beliefstate['user_acts']:
self._convert_inform_by_alternatives(sys_act, q_results, beliefstate)
else:
self._convert_inform_by_constraints(q_results, sys_act, beliefstate)
def _convert_inform_by_primkey(self, q_results: iter,
                               sys_act: SysAct, beliefstate: BeliefState):
    """
    Fills a SysAct with slot values when the system answers a user's request for
    information about an entity.

    Args:
        q_results (iterable): list of query results from the database
        sys_act (SysAct): current raw sys_act to be filled in
        beliefstate (BeliefState): BeliefState object; contains all user informs to date
    """
    sys_act.type = SysActionType.InformByName

    # nothing found: report 'none' for the entity
    if not q_results:
        sys_act.add_value(self.domain_key, 'none')
        return

    first_hit = q_results[0]  # currently only the first result is reported
    reported_slots = list(first_hit.keys())[:4]  # at most the first four columns
    # add slots + values (where available) to the sys_act
    for slot in reported_slots:
        sys_act.add_value(slot, first_hit[slot] or 'not available')

    # Name might not be a constraint in request queries, so add it
    if self.domain_key not in reported_slots:
        sys_act.add_value(self.domain_key, self._get_name(beliefstate))
def _convert_inform_by_alternatives(
        self, sys_act: SysAct, q_res: iter, beliefstate: BeliefState):
    """
    Steps to the next entry in the list of entities matching the user's constraints and
    uses it to fill in the raw inform act.

    Once the end of the list is passed, the entity is reported as 'none' and the index
    is left pointing at the last item of the list.

    Args:
        sys_act (SysAct): the raw inform to be filled in
        q_res (iter): results of the database query
        beliefstate (BeliefState): current system belief state
    """
    # (re)fill the suggestion list only if nothing is stored yet
    if q_res and not self.current_suggestions:
        self.current_suggestions = list(q_res)
        self.s_index = -1

    # scroll through possible offers, presenting one each turn the user asks for
    # alternatives
    self.s_index += 1
    last_index = len(self.current_suggestions) - 1

    if self.s_index > last_index:
        sys_act.type = SysActionType.InformByAlternatives
        # default to last suggestion in the list
        self.s_index = last_index
        sys_act.add_value(self.domain.get_primary_key(), 'none')
    else:
        # the first time we inform, we should inform by name, so we use the right template
        first_offer = self.s_index == 0
        sys_act.type = SysActionType.InformByName if first_offer \
            else SysActionType.InformByAlternatives
        # Inform by alternatives according to our current templates is
        # just a normal inform apparently --LV
        suggestion = self.current_suggestions[self.s_index]
        sys_act.add_value(self.domain_key, suggestion[self.domain_key])

    # in addition to the name, add the constraints the user has specified, so they know the
    # offer is relevant to them
    constraints, _ = self._get_constraints(beliefstate)
    for slot in constraints:
        sys_act.add_value(slot, constraints[slot])
def _convert_inform_by_constraints(self, q_results: iter,
                                   sys_act: SysAct, beliefstate: BeliefState):
    """
    Fills in slots and values of a raw inform act when the system is ready to make the
    user an offer.

    Args:
        q_results (iter): the results from the database query
        sys_act (SysAct): the raw inform act to be filled in
        beliefstate (BeliefState): the current system beliefs
    """
    # TODO: Do we want some way to allow users to scroll through
    # result set other than to type 'alternatives'? --LV
    if not q_results:
        sys_act.add_value(self.domain_key, 'none')
    else:
        # remember all matches and offer the first one
        self.current_suggestions = list(q_results)
        self.s_index = 0
        offer = self.current_suggestions[0]
        sys_act.add_value(self.domain_key, offer[self.domain_key])

    sys_act.type = SysActionType.InformByName
    # Using constraints here rather than results to deal with empty
    # results sets (eg. user requests something impossible) --LV
    constraints, _ = self._get_constraints(beliefstate)
    for slot in constraints:
        sys_act.add_value(slot, constraints[slot])
__init__(self, domain, logger=<DiasysLogger adviser (NOTSET)>, max_turns=25)
special
¶
Initializes the policy
Source code in adviser/services/policy/policy_handcrafted.py
def __init__(self, domain: JSONLookupDomain, logger: DiasysLogger = DiasysLogger(),
             max_turns: int = 25):
    """
    Set up the policy for one domain.

    Arguments:
        domain {domain.jsonlookupdomain.JSONLookupDomain} -- Domain
        logger {DiasysLogger} -- stored on the instance as `self.logger`
        max_turns {int} -- stored as `self.max_turns`
    """
    self.first_turn = True
    Service.__init__(self, domain=domain)
    # suggestions offered so far and the index of the current recommendation among them
    self.current_suggestions, self.s_index = [], 0
    self.domain_key = domain.get_primary_key()
    self.logger = logger
    self.max_turns = max_turns
choose_sys_act(self, *args, **kwargs)
¶
Source code in adviser/services/policy/policy_handcrafted.py
def delegate(self, *args, **kwargs):
    """
    Call the wrapped function and publish the values it returns.

    The wrapped function's result is expected to be a dict keyed by topic, optionally
    written as "topic/domain". Keys are reduced to the bare topic; if this method is not
    registered in `self._publish_sockets` the result is simply returned. Otherwise each
    topic from `pub_topics` that is present in the result is sent on the socket.

    NOTE(review): `func`, `pub_topics` and `_send_msg` are not defined in this function;
    presumably they come from the enclosing decorator scope -- confirm.

    Returns:
        the (key-normalized) result of the wrapped function
    """
    bound_method = getattr(self, func.__name__)

    # self is passed explicitly below, so drop it from the positional arguments
    call_args = list(args)
    if self in call_args:
        call_args.remove(self)

    result = func(self, *call_args, **kwargs)
    if result:
        # fix! (user could have multiple "/" characters in topic - only use last one )
        domains = {}
        for key in result:
            parts = key.split("/")
            domains[parts[0]] = parts[1] if "/" in key else ""
        result = {key.split("/")[0]: result[key] for key in result}

    if bound_method not in self._publish_sockets:
        # not a publisher, just normal function
        return result

    socket = self._publish_sockets[bound_method]
    domain = self._domain_name
    if not (socket and result):
        return result

    # publish messages
    for topic in pub_topics:
        if topic not in result:
            continue
        domain = domain or domains[topic]
        topic_domain_str = f"{topic}/{domain}" if domain else topic
        if topic in self._pub_topic_domains:
            # an explicitly configured per-topic domain takes precedence
            configured = self._pub_topic_domains[topic]
            topic_domain_str = f"{topic}/{configured}" if configured else topic
        _send_msg(socket, topic_domain_str, result[topic])
        if self.debug_logger:
            self.debug_logger.info(
                f"- (DS): sent message from {func} to topic {topic_domain_str}:\n {result[topic]}")
    return result
dialog_start(self)
¶
resets the policy after each dialog
Source code in adviser/services/policy/policy_handcrafted.py
rl
special
¶
dqn
¶
DQN (Module)
¶
Simple Deep Q-Network
Source code in adviser/services/policy/rl/dqn.py
class DQN(nn.Module):
    """ Simple Deep Q-Network: a stack of linear layers with ReLU activations """

    def __init__(self, state_dim: int, action_dim: int, hidden_layer_sizes: List[int] = [300, 300],
                 dropout_rate: float = 0.0):
        """ Build the network.

        Args:
            state_dim: size of the input (state) vector
            action_dim: size of the output vector (one value per action)
            hidden_layer_sizes: output size of each hidden linear layer, in order
            dropout_rate: if > 0, a dropout layer with this probability follows every
                          hidden ReLU
        """
        super().__init__()
        print("Architecture: DQN")
        self.dropout_rate = dropout_rate

        # hidden layers: Linear -> ReLU (-> Dropout)
        self.layers = nn.ModuleList()
        use_dropout = dropout_rate > 0.0
        in_features = state_dim
        for out_features in hidden_layer_sizes:
            hidden = [nn.Linear(in_features, out_features), nn.ReLU()]
            if use_dropout:
                hidden.append(nn.Dropout(p=dropout_rate))
            self.layers.extend(hidden)
            in_features = out_features

        # output layer
        self.layers.append(nn.Linear(in_features, action_dim))

    def forward(self, state_batch: torch.FloatTensor):
        """ Forward pass: calculate Q(state) for all actions

        Args:
            state_batch (torch.FloatTensor): tensor of size batch_size x state_dim

        Returns:
            output: tensor of size batch_size x action_dim
        """
        activations = state_batch
        for module in self.layers:
            activations = module(activations)
        return activations
__init__(self, state_dim, action_dim, hidden_layer_sizes=[300, 300], dropout_rate=0.0)
special
¶
Initialize a DQN Network with an arbitrary amount of linear hidden layers
Source code in adviser/services/policy/rl/dqn.py
def __init__(self, state_dim: int, action_dim: int, hidden_layer_sizes: List[int] = [300, 300],
             dropout_rate: float = 0.0):
    """ Build a DQN network with an arbitrary number of linear hidden layers.

    Args:
        state_dim: size of the input (state) vector
        action_dim: size of the output vector (one value per action)
        hidden_layer_sizes: output size of each hidden linear layer, in order
        dropout_rate: if > 0, a dropout layer with this probability follows every
                      hidden ReLU
    """
    super(DQN, self).__init__()
    print("Architecture: DQN")
    self.dropout_rate = dropout_rate

    # hidden layers: Linear -> ReLU (-> Dropout)
    self.layers = nn.ModuleList()
    use_dropout = dropout_rate > 0.0
    in_features = state_dim
    for out_features in hidden_layer_sizes:
        hidden = [nn.Linear(in_features, out_features), nn.ReLU()]
        if use_dropout:
            hidden.append(nn.Dropout(p=dropout_rate))
        self.layers.extend(hidden)
        in_features = out_features

    # output layer
    self.layers.append(nn.Linear(in_features, action_dim))
forward(self, state_batch)
¶
Forward pass: calculate Q(state) for all actions
Parameters:
Name | Type | Description | Default |
---|---|---|---|
state_batch |
torch.FloatTensor |
tensor of size batch_size x state_dim |
required |
Returns:
Type | Description |
---|---|
output |
tensor of size batch_size x action_dim |
Source code in adviser/services/policy/rl/dqn.py
def forward(self, state_batch: torch.FloatTensor):
    """ Forward pass: calculate Q(state) for all actions by feeding the input through
    every module in `self.layers`, in order.

    Args:
        state_batch (torch.FloatTensor): tensor of size batch_size x state_dim

    Returns:
        output: tensor of size batch_size x action_dim
    """
    activations = state_batch
    for module in self.layers:
        activations = module(activations)
    return activations
DuelingDQN (Module)
¶
Dueling DQN network architecture
Splits network into value- and advantage stream (V(s) and A(s,a)), recombined in final layer to form Q-value again: Q(s,a) = V(s) + A(s,a).
Source code in adviser/services/policy/rl/dqn.py
class DuelingDQN(nn.Module):
    """ Dueling DQN network architecture

    Splits network into value- and advantage stream (V(s) and A(s,a)),
    recombined in final layer to form Q-value again:
    Q(s,a) = V(s) + A(s,a) - mean_a'(A(s,a')).
    """

    def __init__(self, state_dim: int, action_dim: int,
                 shared_layer_sizes: List[int] = [128], value_layer_sizes: List[int] = [128],
                 advantage_layer_sizes: List[int] = [128], dropout_rate: float = 0.0):
        """ Build the shared trunk plus the value and advantage streams.

        Args:
            state_dim: size of the input (state) vector
            action_dim: number of actions (output size of the advantage stream)
            shared_layer_sizes: hidden layer sizes of the shared trunk
            value_layer_sizes: hidden layer sizes of the value stream
            advantage_layer_sizes: hidden layer sizes of the advantage stream
            dropout_rate: if > 0, a dropout layer with this probability follows every
                          hidden ReLU
        """
        super(DuelingDQN, self).__init__()
        print("ARCHITECTURE: Dueling")
        self.dropout_rate = dropout_rate

        # The streams are built in the order shared -> value -> advantage so that layers
        # are created (and randomly initialized) in the same sequence as before.

        # shared layer: state_dim -> shared_layer_sizes[-1]
        self.shared_layers, shared_dim = self._build_stream(
            state_dim, shared_layer_sizes, dropout_rate)

        # value layer: shared_layer_sizes[-1] -> 1
        self.value_layers, value_dim = self._build_stream(
            shared_dim, value_layer_sizes, dropout_rate)
        self.value_layers.append(nn.Linear(value_dim, 1))

        # advantage layer: shared_layer_sizes[-1] -> actions
        self.advantage_layers, advantage_dim = self._build_stream(
            shared_dim, advantage_layer_sizes, dropout_rate)
        self.advantage_layers.append(nn.Linear(advantage_dim, action_dim))

    @staticmethod
    def _build_stream(input_dim: int, layer_sizes: List[int], dropout_rate: float):
        """ Stack Linear -> ReLU (-> Dropout) blocks; returns (ModuleList, output dim) """
        stream = nn.ModuleList()
        current_dim = input_dim
        for layer_size in layer_sizes:
            stream.append(nn.Linear(current_dim, layer_size))
            stream.append(nn.ReLU())
            if dropout_rate > 0.0:
                stream.append(nn.Dropout(p=dropout_rate))
            current_dim = layer_size
        return stream, current_dim

    def forward(self, state_batch: torch.FloatTensor):
        """ Forward pass: calculate Q(state) for all actions

        Args:
            state_batch (torch.FloatTensor): tensor of size batch_size x state_dim

        Returns:
            tensor of size batch_size x action_dim
        """
        # shared layer representation
        shared_output = state_batch
        for layer in self.shared_layers:
            shared_output = layer(shared_output)

        # value stream
        value_stream = shared_output
        for layer in self.value_layers:
            value_stream = layer(value_stream)

        # advantage stream
        advantage_stream = shared_output
        for layer in self.advantage_layers:
            advantage_stream = layer(advantage_stream)

        # combine value and advantage streams into Q values; the advantage is centered
        # per state (mean over the action dimension only), so the Q-values of one sample
        # never depend on the other samples in the batch
        mean_advantage = advantage_stream.mean(dim=-1, keepdim=True)
        return value_stream + advantage_stream - mean_advantage
__init__(self, state_dim, action_dim, shared_layer_sizes=[128], value_layer_sizes=[128], advantage_layer_sizes=[128], dropout_rate=0.0)
special
¶
Source code in adviser/services/policy/rl/dqn.py
def __init__(self, state_dim: int, action_dim: int,
             shared_layer_sizes: List[int] = [128], value_layer_sizes: List[int] = [128],
             advantage_layer_sizes: List[int] = [128], dropout_rate: float = 0.0):
    """ Build the shared trunk plus the value and advantage streams.

    Every hidden layer is a Linear followed by a ReLU and, if `dropout_rate` > 0, a
    Dropout layer. The value stream ends in a single output, the advantage stream in
    `action_dim` outputs.
    """
    super(DuelingDQN, self).__init__()
    print("ARCHITECTURE: Dueling")
    self.dropout_rate = dropout_rate
    use_dropout = dropout_rate > 0.0

    # configure layers
    self.shared_layers = nn.ModuleList()
    self.value_layers = nn.ModuleList()
    self.advantage_layers = nn.ModuleList()

    # shared layer: state_dim -> shared_layer_sizes[-1]
    shared_out = state_dim
    for width in shared_layer_sizes:
        self.shared_layers.extend([nn.Linear(shared_out, width), nn.ReLU()])
        if use_dropout:
            self.shared_layers.append(nn.Dropout(p=dropout_rate))
        shared_out = width

    # value layer: shared_layer_sizes[-1] -> 1
    value_out = shared_out
    for width in value_layer_sizes:
        self.value_layers.extend([nn.Linear(value_out, width), nn.ReLU()])
        if use_dropout:
            self.value_layers.append(nn.Dropout(p=dropout_rate))
        value_out = width
    self.value_layers.append(nn.Linear(value_out, 1))

    # advantage layer: shared_layer_sizes[-1] -> actions
    advantage_out = shared_out
    for width in advantage_layer_sizes:
        self.advantage_layers.extend([nn.Linear(advantage_out, width), nn.ReLU()])
        if use_dropout:
            self.advantage_layers.append(nn.Dropout(p=dropout_rate))
        advantage_out = width
    self.advantage_layers.append(nn.Linear(advantage_out, action_dim))
forward(self, state_batch)
¶
Forward pass: calculate Q(state) for all actions
Parameters:
Name | Type | Description | Default |
---|---|---|---|
input |
torch.FloatTensor |
tensor of size batch_size x state_dim |
required |
Returns:
Type | Description |
---|---|
tensor of size batch_size x action_dim |
Source code in adviser/services/policy/rl/dqn.py
def forward(self, state_batch: torch.FloatTensor):
    """ Forward pass: calculate Q(state) for all actions

    Args:
        state_batch (torch.FloatTensor): tensor of size batch_size x state_dim

    Returns:
        tensor of size batch_size x action_dim
    """
    # shared layer representation
    shared_output = state_batch
    for layer in self.shared_layers:
        shared_output = layer(shared_output)

    # value stream
    value_stream = shared_output
    for layer in self.value_layers:
        value_stream = layer(value_stream)

    # advantage stream
    advantage_stream = shared_output
    for layer in self.advantage_layers:
        advantage_stream = layer(advantage_stream)

    # combine value and advantage streams into Q values; the advantage is centered
    # per state (mean over the action dimension only), so the Q-values of one sample
    # never depend on the other samples in the batch
    mean_advantage = advantage_stream.mean(dim=-1, keepdim=True)
    return value_stream + advantage_stream - mean_advantage
NetArchitecture (Enum)
¶
Network architecture for DQN
vanilla: normal MLP; dueling: splits the network into a value stream and an advantage stream, which are recombined in the final layer
Source code in adviser/services/policy/rl/dqn.py
dqnpolicy
¶
DQNPolicy (RLPolicy, Service)
¶
Source code in adviser/services/policy/rl/dqnpolicy.py
class DQNPolicy(RLPolicy, Service):
def __init__(self, domain: JSONLookupDomain,
             architecture: NetArchitecture = NetArchitecture.DUELING,
             hidden_layer_sizes: List[int] = [256, 700, 700],  # vanilla architecture
             shared_layer_sizes: List[int] = [256], value_layer_sizes: List[int] = [300, 300],
             advantage_layer_sizes: List[int] = [400, 400],  # dueling architecture
             lr: float = 0.0001, discount_gamma: float = 0.99,
             target_update_rate: int = 3,
             replay_buffer_size: int = 8192, batch_size: int = 64,
             buffer_cls: Type[Buffer] = NaivePrioritizedBuffer,
             eps_start: float = 0.3, eps_end: float = 0.0,
             l2_regularisation: float = 0.0, gradient_clipping: float = 5.0,
             p_dropout: float = 0.0, training_frequency: int = 2, train_dialogs: int = 1000,
             include_confreq: bool = False, logger: DiasysLogger = DiasysLogger(),
             max_turns: int = 25,
             summary_writer: SummaryWriter = None, device=torch.device('cpu')):
    """
    Builds the online network, the optional target network, the optimizer
    and the loss, and initialises the training counters.

    Args:
        architecture: NetArchitecture.VANILLA builds a DQN from
                      hidden_layer_sizes; any other value builds a DuelingDQN
                      from shared_/value_/advantage_layer_sizes
        lr: learning rate handed to the Adam optimizer
        l2_regularisation: weight decay handed to the Adam optimizer
        target_update_rate: if 1, vanilla dqn update
                            if > 1, double dqn with specified target update
                            rate
        eps_start: initial epsilon of the exploration schedule
        eps_end: final epsilon of the exploration schedule
        summary_writer: optional writer used for scalar logging
    """
    # NOTE: the list defaults are shared between calls; they are only read
    # here, never mutated.
    RLPolicy.__init__(
        self,
        domain, buffer_cls=buffer_cls,
        buffer_size=replay_buffer_size, batch_size=batch_size,
        discount_gamma=discount_gamma, include_confreq=include_confreq,
        logger=logger, max_turns=max_turns, device=device)
    Service.__init__(self, domain=domain)

    self.writer = summary_writer
    self.training_frequency = training_frequency
    self.train_dialogs = train_dialogs
    self.lr = lr
    self.gradient_clipping = gradient_clipping
    if gradient_clipping > 0.0 and self.logger:
        self.logger.info("Gradient Clipping: " + str(gradient_clipping))
    self.target_update_rate = target_update_rate
    self.epsilon_start = eps_start
    self.epsilon_end = eps_end

    # Select network architecture
    if architecture == NetArchitecture.VANILLA:
        if self.logger:
            self.logger.info("Architecture: Vanilla")
        self.model = DQN(self.state_dim, self.action_dim,
                         hidden_layer_sizes=hidden_layer_sizes,
                         dropout_rate=p_dropout)
    else:
        if self.logger:
            self.logger.info("Architecture: Dueling")
        self.model = DuelingDQN(self.state_dim, self.action_dim,
                                shared_layer_sizes=shared_layer_sizes,
                                value_layer_sizes=value_layer_sizes,
                                advantage_layer_sizes=advantage_layer_sizes,
                                dropout_rate=p_dropout)

    # Select network update
    self.target_model = None
    if target_update_rate > 1:
        if self.logger:
            self.logger.info("Update: Double")
        # Double DQN needs a target network whatever the architecture is.
        # Creating it only for the vanilla net left the dueling net on
        # vanilla targets although "Update: Double" was logged.
        self.target_model = copy.deepcopy(self.model)
    elif self.logger:
        self.logger.info("Update: Vanilla")

    self.optim = optim.Adam(self.model.parameters(), lr=lr, weight_decay=l2_regularisation)
    # element-wise loss: train_batch weights it per sample before averaging
    self.loss_fun = nn.SmoothL1Loss(reduction='none')

    self.train_call_count = 0
    self.total_train_dialogs = 0
    self.epsilon = self.epsilon_start
    self.turns = 0
    # starts at -1 so that the first training dialog gets index 0
    self.cumulative_train_dialogs = -1
def dialog_start(self, dialog_start=False):
    """ Reset the dialog-level bookkeeping before a new dialog begins.

    Clears the turn counter and the last system act, counts the dialog
    towards the training total while in training mode, and installs a fresh
    ``sys_state`` dictionary. The ``dialog_start`` argument is not used.
    """
    self.turns, self.last_sys_act = 0, None
    if self.is_training:
        self.cumulative_train_dialogs = self.cumulative_train_dialogs + 1
    fresh_state = {}
    fresh_state["lastInformedPrimKeyVal"] = None
    fresh_state["lastActionInformNone"] = False
    fresh_state["offerHappened"] = False
    fresh_state['informedPrimKeyValsSinceNone'] = []
    self.sys_state = fresh_state
def select_action_eps_greedy(self, state_vector: torch.FloatTensor):
    """ Epsilon-greedy policy.

    Updates epsilon via ``eps_scheduler`` first. While training, a uniformly
    random action is chosen with probability epsilon; otherwise the action
    with the highest Q-value is returned.

    Args:
        state_vector (torch.FloatTensor): current state (dimension 1 x state_dim)

    Returns:
        action index for action selected by the agent for the current state
    """
    self.eps_scheduler()
    # epsilon greedy exploration
    if self.is_training and common.random.random() < self.epsilon:
        return common.random.randint(0, self.action_dim - 1)

    # Greedy choice: no gradients are needed. The context manager restores
    # the caller's autograd mode (also when the model raises) instead of
    # unconditionally switching gradient tracking back on.
    with torch.no_grad():
        q_values = self.model(state_vector)
        return q_values.squeeze(dim=0).max(dim=0)[1].item()
@PublishSubscribe(sub_topics=["sim_goal"])
def end(self, sim_goal: Goal):
    """
    Remember the simulation goal once the simulation ends; it is handed to
    ``end_dialog`` when the dialog is closed.

    Args:
        sim_goal (Goal): the simulation goal, needed for evaluation
    """
    self.sim_goal = sim_goal
def dialog_end(self):
    """
    Clean up at the end of a dialog: close the dialog with the stored
    simulation goal, count it while training and trigger a training step.
    """
    self.end_dialog(self.sim_goal)
    training_now = self.is_training
    if training_now:
        self.total_train_dialogs = self.total_train_dialogs + 1
    # train_batch() returns immediately when the policy is not training
    self.train_batch()
@PublishSubscribe(sub_topics=["beliefstate"], pub_topics=["sys_act", "sys_state"])
def choose_sys_act(self, beliefstate: BeliefState = None) -> dict(sys_act=SysAct):
    """
    Determine the next system act based on the given beliefstate

    Args:
        beliefstate (BeliefState): beliefstate, contains all information the system knows
                                   about the environment (in this case the user)

    Returns:
        (dict): dictionary where the keys are "sys_act" representing the action chosen by
                the policy, and "sys_state" which contains additional information which might
                be needed by the NLU to disambiguate challenging utterances.
    """
    self.num_dialogs = self.cumulative_train_dialogs % self.train_dialogs
    if self.cumulative_train_dialogs == 0 and self.target_model is not None:
        # start with same weights for target and online net
        self.target_model.load_state_dict(self.model.state_dict())

    self.turns += 1

    if self.turns == 1:
        # first turn of dialog: say hello & don't record
        hello_output = self._expand_hello()
        hello_output["sys_state"] = {"last_act": hello_output["sys_act"]}
        return hello_output

    if self.turns > self.max_turns:
        # reached turn limit -> terminate dialog
        bye_action = SysAct()
        bye_action.type = SysActionType.Bye
        self.last_sys_act = bye_action
        if self.logger:
            self.logger.dialog_turn("system action > " + str(bye_action))
        return {'sys_act': bye_action, "sys_state": {"last_act": bye_action}}

    # intermediate or closing turn
    state_vector = self.beliefstate_dict_to_vector(beliefstate)
    if UserActionType.Bye in beliefstate["user_acts"]:
        # user terminated current dialog -> say bye
        next_action_idx = self.action_idx(SysActionType.Bye.value)
    else:
        next_action_idx = -1
    if next_action_idx == -1:
        # dialog continues
        next_action_idx = self.select_action_eps_greedy(state_vector)

    self.turn_end(beliefstate, state_vector, next_action_idx)

    # Update the sys_state
    inform_types = [SysActionType.InformByName, SysActionType.InformByAlternatives]
    if self.last_sys_act.type in inform_types:
        values = self.last_sys_act.get_values(self.domain.get_primary_key())
        if values:
            self.sys_state['lastInformedPrimKeyVal'] = values[0]
    elif self.last_sys_act.type == SysActionType.Request:
        requested_slots = list(self.last_sys_act.slot_values.keys())
        if len(requested_slots) > 0:
            self.sys_state['lastRequestSlot'] = requested_slots[0]

    self.sys_state["last_act"] = self.last_sys_act
    return {'sys_act': self.last_sys_act, "sys_state": self.sys_state}
def _forward(self, state: torch.FloatTensor, action: torch.LongTensor):
    """ Run the online network and pick out the Q-value of each given action.

    Args:
        state (torch.FloatTensor): states (dimension batch x state_dim)
        action (torch.LongTensor): actions to select Q-value for (dimension batch x 1)

    Returns:
        Q-values for selected actions
    """
    return self.model(state).gather(1, action)
def _forward_target(self, state: torch.FloatTensor, reward: torch.FloatTensor,
                    terminal: torch.FloatTensor, gamma: float):
    """ Calculate target for TD-loss (DQN): the online network both selects
    and evaluates the greedy action.

    Args:
        state (torch.FloatTensor): states (dimension batch x state_dim)
        reward (torch.FloatTensor): rewards (dimension batch x 1)
        terminal (torch.FloatTensor): indicator {0,1} for terminal states (dimension: batch x 1)
        gamma (float): discount factor

    Returns:
        TD-loss targets
    """
    q_all = self.model(state)
    best_idx = q_all.max(1)[1].unsqueeze(1)
    best_q = q_all.gather(1, best_idx)
    not_done = 1.0 - terminal
    return reward + not_done * gamma * best_q
def _forward_target_ddqn(self, state: torch.FloatTensor, reward: torch.FloatTensor,
                         terminal: torch.FloatTensor, gamma: float):
    """ Calculate target for TD-loss (Double DQN): the online network selects
    the greedy action, the target network evaluates it.

    Args:
        state (torch.FloatTensor): states (dimension batch x state_dim)
        reward (torch.FloatTensor): rewards (dimension batch x 1)
        terminal (torch.FloatTensor): indicator {0,1} for terminal states (dimension: batch x 1)
        gamma (float): discount factor

    Returns:
        TD-loss targets
    """
    online_choice = self.model(state).max(1)[1].unsqueeze(1)
    evaluated_q = self.target_model(state).gather(1, online_choice)
    return reward + (1.0 - terminal) * gamma * evaluated_q
def loss(self, s_batch: torch.FloatTensor, a_batch: torch.LongTensor,
         s2_batch: torch.FloatTensor, r_batch: torch.FloatTensor, t_batch: torch.FloatTensor,
         gamma: float):
    """ Calculate TD-loss for given experience tuples

    The prediction is computed with gradient tracking on, the target with
    gradient tracking off; tracking is switched on again before returning.

    Args:
        s_batch (torch.FloatTensor): states (dimension batch x state_dim)
        a_batch (torch.LongTensor): actions (dimension batch x 1)
        s2_batch (torch.FloatTensor): next states (dimension: batch x state_dim)
        r_batch (torch.FloatTensor): rewards (dimension batch x 1)
        t_batch (torch.FloatTensor): indicator {0,1} for terminal states (dimension: batch x 1)
        gamma (float): discount factor

    Returns:
        TD-loss
    """
    # prediction: Q(s, a) of the online network
    torch.autograd.set_grad_enabled(True)
    predicted_q = self._forward(s_batch, a_batch)

    # target: plain DQN without a target network, Double DQN with one
    torch.autograd.set_grad_enabled(False)
    if self.target_model is not None:
        td_target = self._forward_target_ddqn(s2_batch, r_batch, t_batch, gamma)
    else:
        td_target = self._forward_target(s2_batch, r_batch, t_batch, gamma)
    torch.autograd.set_grad_enabled(True)

    return self.loss_fun(predicted_q, td_target)
def train_batch(self):
""" Train on a minibatch drawn from the experience buffer.
Returns immediately when the policy is not in training mode. A training
step is only taken once the buffer holds at least 10 * batch_size items and
total_train_dialogs is a multiple of training_frequency.
NOTE(review): this listing has lost its indentation, so the exact nesting of
the statements below (priority update loop, summary-writer block, target
network sync) cannot be confirmed from here - check against the source file.
"""
if not self.is_training:
return
if len(self.buffer) >= self.batch_size * 10 and \
self.total_train_dialogs % self.training_frequency == 0:
self.train_call_count += 1
# NOTE(review): `indices` is unpacked here but never referenced again in this method
s_batch, a_batch, r_batch, s2_batch, t_batch, indices, importance_weights = \
self.buffer.sample()
self.optim.zero_grad()
torch.autograd.set_grad_enabled(True)
s_batch.requires_grad_()
# discount factor repeated once per sample, shaped batch_size x 1
gamma = torch.tensor([self.discount_gamma] * self.batch_size, dtype=torch.float,
device=self.device).view(self.batch_size, 1)
# calculate loss
loss = self.loss(s_batch, a_batch, s2_batch, r_batch, t_batch, gamma)
if importance_weights is not None:
# importance weighting
loss = loss * importance_weights
for i in range(self.batch_size):
# update priorities
# NOTE(review): update() receives the position within the batch, not an entry of
# `indices` - presumably the buffer remembers what it sampled; verify
self.buffer.update(i, loss[i].item())
# reduce the per-sample losses to a scalar before backpropagation
loss = loss.mean()
loss.backward()
# clip gradients
if self.gradient_clipping > 0.0:
nn.utils.clip_grad_norm_(self.model.parameters(), self.gradient_clipping)
# update weights
self.optim.step()
current_loss = loss.item()
# gradient tracking stays switched off after the update
torch.autograd.set_grad_enabled(False)
if self.writer is not None:
# plot loss
self.writer.add_scalar('train/loss', current_loss, self.train_call_count)
# plot min/max gradients
# sentinel start values, replaced by the first parameter gradient norm seen
max_grad_norm = -1.0
min_grad_norm = 1000000.0
for param in self.model.parameters():
if param.grad is not None:
# TODO decide on norm
current_grad_norm = torch.norm(param.grad, 2)
if current_grad_norm > max_grad_norm:
max_grad_norm = current_grad_norm
if current_grad_norm < min_grad_norm:
min_grad_norm = current_grad_norm
self.writer.add_scalar('train/min_grad', min_grad_norm, self.train_call_count)
self.writer.add_scalar('train/max_grad', max_grad_norm, self.train_call_count)
# update target net: copy the online weights every target_update_rate training calls
if self.target_model is not None and \
self.train_call_count % self.target_update_rate == 0:
self.target_model.load_state_dict(self.model.state_dict())
def eps_scheduler(self):
    """ Linear epsilon decay

    While training, epsilon is interpolated from epsilon_start towards
    epsilon_end according to num_dialogs / train_dialogs (never below 0),
    and logged to the summary writer if there is one.
    """
    if not self.is_training:
        return
    decay_range = self.epsilon_start - self.epsilon_end
    decayed = self.epsilon_start \
        - decay_range * float(self.num_dialogs) / float(self.train_dialogs)
    self.epsilon = max(0, decayed)
    if self.writer is not None:
        self.writer.add_scalar('train/eps', self.epsilon, self.total_train_dialogs)
def save(self, path: str = os.path.join('models', 'dqn'), version: str = "1.0"):
    """ Save the model to <path>/rlpolicy_<domain name>_<version>.pt

    The complete model object is handed to torch.save. The target folder is
    created if it does not exist yet.

    Args:
        path (str): path to model folder
        version (str): appendix to filename, enables having multiple models for the same domain
                       (or saving a model after each training epoch)
    """
    folder_missing = not os.path.exists(path)
    if folder_missing:
        os.makedirs(path, exist_ok=True)
    file_name = "".join(("rlpolicy_", self.domain.get_domain_name(), "_", version, ".pt"))
    torch.save(self.model, os.path.join(path, file_name))
def load(self, path: str = os.path.join('models', 'dqn'), version: str = "1.0"):
    """ Load the model from <path>/rlpolicy_<domain name>_<version>.pt

    The loaded object replaces ``self.model``; an existing target network
    receives a copy of its weights.

    Args:
        path (str): path to model folder
        version (str): appendix to filename, enables having multiple models for the same domain
                       (or saving a model after each training epoch)

    Raises:
        FileNotFoundError: if the model file does not exist
    """
    model_file = os.path.join(
        path, "rlpolicy_" + self.domain.get_domain_name() + "_" + version + ".pt")
    if not os.path.isfile(model_file):
        # One message string: the two-argument OSError form would interpret
        # the message as errno and the path as strerror.
        raise FileNotFoundError("Could not find DQN policy weight file " + model_file)
    # NOTE(review): torch.load unpickles the file - only load trusted model
    # files. Recent PyTorch releases may additionally need weights_only=False
    # to load a whole pickled model; confirm against the supported version.
    # NOTE(review): an optimizer built from the previous self.model is not
    # rebound here - confirm whether training after load() is expected.
    self.model = torch.load(model_file)
    # the logger is optional everywhere else in this class, so guard it here too
    if self.logger:
        self.logger.info("Loaded DQN weights from file " + model_file)
    if self.target_model is not None:
        self.target_model.load_state_dict(self.model.state_dict())
def train(self, train=True):
    """ Sets module and its subgraph to training mode

    NOTE(review): the ``train`` argument is never consulted - training mode
    is enabled regardless of its value.
    """
    super(DQNPolicy, self).train()
    self.is_training = True
    self.model.train()
    target_net = self.target_model
    if target_net is not None:
        target_net.train()
def eval(self, eval=True):
    """ Sets module and its subgraph to eval mode

    NOTE(review): the ``eval`` argument is never consulted - eval mode is
    enabled regardless of its value.
    """
    super(DQNPolicy, self).eval()
    self.is_training = False
    self.model.eval()
    target_net = self.target_model
    if target_net is not None:
        target_net.eval()
__init__(self, domain, architecture=<NetArchitecture.DUELING: 'dueling'>, hidden_layer_sizes=[256, 700, 700], shared_layer_sizes=[256], value_layer_sizes=[300, 300], advantage_layer_sizes=[400, 400], lr=0.0001, discount_gamma=0.99, target_update_rate=3, replay_buffer_size=8192, batch_size=64, buffer_cls=<class 'services.policy.rl.experience_buffer.NaivePrioritizedBuffer'>, eps_start=0.3, eps_end=0.0, l2_regularisation=0.0, gradient_clipping=5.0, p_dropout=0.0, training_frequency=2, train_dialogs=1000, include_confreq=False, logger=<DiasysLogger adviser (NOTSET)>, max_turns=25, summary_writer=None, device=device(type='cpu'))
special
¶
Parameters:
Name | Type | Description | Default |
---|---|---|---|
target_update_rate |
int |
if 1, vanilla DQN update; if > 1, double DQN with the specified target update rate |
3 |
Source code in adviser/services/policy/rl/dqnpolicy.py
def __init__(self, domain: JSONLookupDomain,
             architecture: NetArchitecture = NetArchitecture.DUELING,
             hidden_layer_sizes: List[int] = [256, 700, 700],  # vanilla architecture
             shared_layer_sizes: List[int] = [256], value_layer_sizes: List[int] = [300, 300],
             advantage_layer_sizes: List[int] = [400, 400],  # dueling architecture
             lr: float = 0.0001, discount_gamma: float = 0.99,
             target_update_rate: int = 3,
             replay_buffer_size: int = 8192, batch_size: int = 64,
             buffer_cls: Type[Buffer] = NaivePrioritizedBuffer,
             eps_start: float = 0.3, eps_end: float = 0.0,
             l2_regularisation: float = 0.0, gradient_clipping: float = 5.0,
             p_dropout: float = 0.0, training_frequency: int = 2, train_dialogs: int = 1000,
             include_confreq: bool = False, logger: DiasysLogger = DiasysLogger(),
             max_turns: int = 25,
             summary_writer: SummaryWriter = None, device=torch.device('cpu')):
    """
    Builds the online network, the optional target network, the optimizer
    and the loss, and initialises the training counters.

    Args:
        architecture: NetArchitecture.VANILLA builds a DQN from
                      hidden_layer_sizes; any other value builds a DuelingDQN
                      from shared_/value_/advantage_layer_sizes
        lr: learning rate handed to the Adam optimizer
        l2_regularisation: weight decay handed to the Adam optimizer
        target_update_rate: if 1, vanilla dqn update
                            if > 1, double dqn with specified target update
                            rate
        eps_start: initial epsilon of the exploration schedule
        eps_end: final epsilon of the exploration schedule
        summary_writer: optional writer used for scalar logging
    """
    # NOTE: the list defaults are shared between calls; they are only read
    # here, never mutated.
    RLPolicy.__init__(
        self,
        domain, buffer_cls=buffer_cls,
        buffer_size=replay_buffer_size, batch_size=batch_size,
        discount_gamma=discount_gamma, include_confreq=include_confreq,
        logger=logger, max_turns=max_turns, device=device)
    Service.__init__(self, domain=domain)

    self.writer = summary_writer
    self.training_frequency = training_frequency
    self.train_dialogs = train_dialogs
    self.lr = lr
    self.gradient_clipping = gradient_clipping
    if gradient_clipping > 0.0 and self.logger:
        self.logger.info("Gradient Clipping: " + str(gradient_clipping))
    self.target_update_rate = target_update_rate
    self.epsilon_start = eps_start
    self.epsilon_end = eps_end

    # Select network architecture
    if architecture == NetArchitecture.VANILLA:
        if self.logger:
            self.logger.info("Architecture: Vanilla")
        self.model = DQN(self.state_dim, self.action_dim,
                         hidden_layer_sizes=hidden_layer_sizes,
                         dropout_rate=p_dropout)
    else:
        if self.logger:
            self.logger.info("Architecture: Dueling")
        self.model = DuelingDQN(self.state_dim, self.action_dim,
                                shared_layer_sizes=shared_layer_sizes,
                                value_layer_sizes=value_layer_sizes,
                                advantage_layer_sizes=advantage_layer_sizes,
                                dropout_rate=p_dropout)

    # Select network update
    self.target_model = None
    if target_update_rate > 1:
        if self.logger:
            self.logger.info("Update: Double")
        # Double DQN needs a target network whatever the architecture is.
        # Creating it only for the vanilla net left the dueling net on
        # vanilla targets although "Update: Double" was logged.
        self.target_model = copy.deepcopy(self.model)
    elif self.logger:
        self.logger.info("Update: Vanilla")

    self.optim = optim.Adam(self.model.parameters(), lr=lr, weight_decay=l2_regularisation)
    # element-wise loss: train_batch weights it per sample before averaging
    self.loss_fun = nn.SmoothL1Loss(reduction='none')

    self.train_call_count = 0
    self.total_train_dialogs = 0
    self.epsilon = self.epsilon_start
    self.turns = 0
    # starts at -1 so that the first training dialog gets index 0
    self.cumulative_train_dialogs = -1
choose_sys_act(self, *args, **kwargs)
¶
Source code in adviser/services/policy/rl/dqnpolicy.py
def delegate(self, *args, **kwargs):
""" Call the wrapped function and publish the entries of its result dict.
`func`, `pub_topics` and `_send_msg` are not defined in this block; they are
presumably provided by the enclosing decorator scope - verify there.
Returns the (domain-stripped) result dict of the wrapped function.
"""
# the bound attribute of the same name is used as key into self._publish_sockets
func_inst = getattr(self, func.__name__)
callargs = list(args)
if self in callargs: # remove self when in *args, because already known to function
callargs.remove(self)
result = func(self, *callargs, **kwargs)
if result:
# Result keys may look like "topic/domain": remember the domain part per topic,
# then strip it from the keys.
# fix! (user could have multiple "/" characters in topic - only use last one )
domains = {res.split("/")[0]: res.split("/")[1] if "/" in res else "" for res in result}
result = {key.split("/")[0]: result[key] for key in result}
if func_inst not in self._publish_sockets:
# not a publisher, just normal function
return result
socket = self._publish_sockets[func_inst]
domain = self._domain_name
if socket and result:
# publish messages
for topic in pub_topics:
# for topic in result: # NOTE publish any returned value in dict with it's key as topic
if topic in result:
# the service's own domain name wins; otherwise fall back to the domain from the result key
domain = domain if domain else domains[topic]
topic_domain_str = f"{topic}/{domain}" if domain else topic
# an entry in self._pub_topic_domains overrides the domain for this topic
if topic in self._pub_topic_domains:
topic_domain_str = f"{topic}/{self._pub_topic_domains[topic]}" if self._pub_topic_domains[topic] else topic
_send_msg(socket, topic_domain_str, result[topic])
if self.debug_logger:
self.debug_logger.info(
f"- (DS): sent message from {func} to topic {topic_domain_str}:\n {result[topic]}")
return result
dialog_end(self)
¶
dialog_start(self, dialog_start=False)
¶
This function is called before the first message of a new dialog is published. Override it to set or reset dialog-level variables.
Source code in adviser/services/policy/rl/dqnpolicy.py
end(self, *args, **kwargs)
¶
Source code in adviser/services/policy/rl/dqnpolicy.py
def delegate(self, *args, **kwargs):
""" Call the wrapped function and publish the entries of its result dict.
`func`, `pub_topics` and `_send_msg` are not defined in this block; they are
presumably provided by the enclosing decorator scope - verify there.
Returns the (domain-stripped) result dict of the wrapped function.
"""
# the bound attribute of the same name is used as key into self._publish_sockets
func_inst = getattr(self, func.__name__)
callargs = list(args)
if self in callargs: # remove self when in *args, because already known to function
callargs.remove(self)
result = func(self, *callargs, **kwargs)
if result:
# Result keys may look like "topic/domain": remember the domain part per topic,
# then strip it from the keys.
# fix! (user could have multiple "/" characters in topic - only use last one )
domains = {res.split("/")[0]: res.split("/")[1] if "/" in res else "" for res in result}
result = {key.split("/")[0]: result[key] for key in result}
if func_inst not in self._publish_sockets:
# not a publisher, just normal function
return result
socket = self._publish_sockets[func_inst]
domain = self._domain_name
if socket and result:
# publish messages
for topic in pub_topics:
# for topic in result: # NOTE publish any returned value in dict with it's key as topic
if topic in result:
# the service's own domain name wins; otherwise fall back to the domain from the result key
domain = domain if domain else domains[topic]
topic_domain_str = f"{topic}/{domain}" if domain else topic
# an entry in self._pub_topic_domains overrides the domain for this topic
if topic in self._pub_topic_domains:
topic_domain_str = f"{topic}/{self._pub_topic_domains[topic]}" if self._pub_topic_domains[topic] else topic
_send_msg(socket, topic_domain_str, result[topic])
if self.debug_logger:
self.debug_logger.info(
f"- (DS): sent message from {func} to topic {topic_domain_str}:\n {result[topic]}")
return result
eps_scheduler(self)
¶
Linear epsilon decay
Source code in adviser/services/policy/rl/dqnpolicy.py
def eps_scheduler(self):
    """ Linear epsilon decay

    While training, epsilon is interpolated from epsilon_start towards
    epsilon_end according to num_dialogs / train_dialogs (never below 0),
    and logged to the summary writer if there is one.
    """
    if not self.is_training:
        return
    decay_range = self.epsilon_start - self.epsilon_end
    decayed = self.epsilon_start \
        - decay_range * float(self.num_dialogs) / float(self.train_dialogs)
    self.epsilon = max(0, decayed)
    if self.writer is not None:
        self.writer.add_scalar('train/eps', self.epsilon, self.total_train_dialogs)
eval(self, eval=True)
¶
Sets module and its subgraph to eval mode
load(self, path='models/dqn', version='1.0')
¶
Load model weights
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
str |
path to model folder |
'models/dqn' |
version |
str |
appendix to filename, enables having multiple models for the same domain (or saving a model after each training epoch) |
'1.0' |
Source code in adviser/services/policy/rl/dqnpolicy.py
def load(self, path: str = os.path.join('models', 'dqn'), version: str = "1.0"):
    """ Load the model from <path>/rlpolicy_<domain name>_<version>.pt

    The loaded object replaces ``self.model``; an existing target network
    receives a copy of its weights.

    Args:
        path (str): path to model folder
        version (str): appendix to filename, enables having multiple models for the same domain
                       (or saving a model after each training epoch)

    Raises:
        FileNotFoundError: if the model file does not exist
    """
    model_file = os.path.join(
        path, "rlpolicy_" + self.domain.get_domain_name() + "_" + version + ".pt")
    if not os.path.isfile(model_file):
        # One message string: the two-argument OSError form would interpret
        # the message as errno and the path as strerror.
        raise FileNotFoundError("Could not find DQN policy weight file " + model_file)
    # NOTE(review): torch.load unpickles the file - only load trusted model
    # files. Recent PyTorch releases may additionally need weights_only=False
    # to load a whole pickled model; confirm against the supported version.
    self.model = torch.load(model_file)
    # the logger is treated as optional elsewhere in this class, so guard it here too
    if self.logger:
        self.logger.info("Loaded DQN weights from file " + model_file)
    if self.target_model is not None:
        self.target_model.load_state_dict(self.model.state_dict())
loss(self, s_batch, a_batch, s2_batch, r_batch, t_batch, gamma)
¶
Calculate TD-loss for given experience tuples
Parameters:
Name | Type | Description | Default |
---|---|---|---|
s_batch |
torch.FloatTensor |
states (dimension batch x state_dim) |
required |
a_batch |
torch.LongTensor |
actions (dimension batch x 1) |
required |
s2_batch |
torch.FloatTensor |
next states (dimension: batch x state_dim) |
required |
r_batch |
torch.FloatTensor |
rewards (dimension batch x 1) |
required |
t_batch |
torch.FloatTensor |
indicator {0,1} for terminal states (dimension: batch x 1) |
required |
gamma |
float |
discount factor |
required |
Returns:
Type | Description |
---|---|
TD-loss |
Source code in adviser/services/policy/rl/dqnpolicy.py
def loss(self, s_batch: torch.FloatTensor, a_batch: torch.LongTensor,
         s2_batch: torch.FloatTensor, r_batch: torch.FloatTensor, t_batch: torch.FloatTensor,
         gamma: float):
    """ Calculate TD-loss for given experience tuples

    The prediction is computed with gradient tracking on, the target with
    gradient tracking off; tracking is switched on again before returning.

    Args:
        s_batch (torch.FloatTensor): states (dimension batch x state_dim)
        a_batch (torch.LongTensor): actions (dimension batch x 1)
        s2_batch (torch.FloatTensor): next states (dimension: batch x state_dim)
        r_batch (torch.FloatTensor): rewards (dimension batch x 1)
        t_batch (torch.FloatTensor): indicator {0,1} for terminal states (dimension: batch x 1)
        gamma (float): discount factor

    Returns:
        TD-loss
    """
    # prediction: Q(s, a) of the online network
    torch.autograd.set_grad_enabled(True)
    predicted_q = self._forward(s_batch, a_batch)

    # target: plain DQN without a target network, Double DQN with one
    torch.autograd.set_grad_enabled(False)
    if self.target_model is not None:
        td_target = self._forward_target_ddqn(s2_batch, r_batch, t_batch, gamma)
    else:
        td_target = self._forward_target(s2_batch, r_batch, t_batch, gamma)
    torch.autograd.set_grad_enabled(True)

    return self.loss_fun(predicted_q, td_target)
save(self, path='models/dqn', version='1.0')
¶
Save model weights
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
str |
path to model folder |
'models/dqn' |
version |
str |
appendix to filename, enables having multiple models for the same domain (or saving a model after each training epoch) |
'1.0' |
Source code in adviser/services/policy/rl/dqnpolicy.py
def save(self, path: str = os.path.join('models', 'dqn'), version: str = "1.0"):
    """ Save the model to <path>/rlpolicy_<domain name>_<version>.pt

    The complete model object is handed to torch.save. The target folder is
    created if it does not exist yet.

    Args:
        path (str): path to model folder
        version (str): appendix to filename, enables having multiple models for the same domain
                       (or saving a model after each training epoch)
    """
    folder_missing = not os.path.exists(path)
    if folder_missing:
        os.makedirs(path, exist_ok=True)
    file_name = "".join(("rlpolicy_", self.domain.get_domain_name(), "_", version, ".pt"))
    torch.save(self.model, os.path.join(path, file_name))
select_action_eps_greedy(self, state_vector)
¶
Epsilon-greedy policy.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
state_vector |
torch.FloatTensor |
current state (dimension 1 x state_dim) |
required |
Returns:
Type | Description |
---|---|
action index for action selected by the agent for the current state |
Source code in adviser/services/policy/rl/dqnpolicy.py
def select_action_eps_greedy(self, state_vector: torch.FloatTensor):
    """ Epsilon-greedy policy.

    Updates epsilon via ``eps_scheduler`` first. While training, a uniformly
    random action is chosen with probability epsilon; otherwise the action
    with the highest Q-value is returned.

    Args:
        state_vector (torch.FloatTensor): current state (dimension 1 x state_dim)

    Returns:
        action index for action selected by the agent for the current state
    """
    self.eps_scheduler()
    # epsilon greedy exploration
    if self.is_training and common.random.random() < self.epsilon:
        return common.random.randint(0, self.action_dim - 1)

    # Greedy choice: no gradients are needed. The context manager restores
    # the caller's autograd mode (also when the model raises) instead of
    # unconditionally switching gradient tracking back on.
    with torch.no_grad():
        q_values = self.model(state_vector)
        return q_values.squeeze(dim=0).max(dim=0)[1].item()
train(self, train=True)
¶
Sets module and its subgraph to training mode
train_batch(self)
¶
Train on a minibatch drawn from the experience buffer.
Source code in adviser/services/policy/rl/dqnpolicy.py
def train_batch(self):
""" Train on a minibatch drawn from the experience buffer.
Returns immediately when the policy is not in training mode. A training
step is only taken once the buffer holds at least 10 * batch_size items and
total_train_dialogs is a multiple of training_frequency.
NOTE(review): this listing has lost its indentation, so the exact nesting of
the statements below (priority update loop, summary-writer block, target
network sync) cannot be confirmed from here - check against the source file.
"""
if not self.is_training:
return
if len(self.buffer) >= self.batch_size * 10 and \
self.total_train_dialogs % self.training_frequency == 0:
self.train_call_count += 1
# NOTE(review): `indices` is unpacked here but never referenced again in this method
s_batch, a_batch, r_batch, s2_batch, t_batch, indices, importance_weights = \
self.buffer.sample()
self.optim.zero_grad()
torch.autograd.set_grad_enabled(True)
s_batch.requires_grad_()
# discount factor repeated once per sample, shaped batch_size x 1
gamma = torch.tensor([self.discount_gamma] * self.batch_size, dtype=torch.float,
device=self.device).view(self.batch_size, 1)
# calculate loss
loss = self.loss(s_batch, a_batch, s2_batch, r_batch, t_batch, gamma)
if importance_weights is not None:
# importance weighting
loss = loss * importance_weights
for i in range(self.batch_size):
# update priorities
# NOTE(review): update() receives the position within the batch, not an entry of
# `indices` - presumably the buffer remembers what it sampled; verify
self.buffer.update(i, loss[i].item())
# reduce the per-sample losses to a scalar before backpropagation
loss = loss.mean()
loss.backward()
# clip gradients
if self.gradient_clipping > 0.0:
nn.utils.clip_grad_norm_(self.model.parameters(), self.gradient_clipping)
# update weights
self.optim.step()
current_loss = loss.item()
# gradient tracking stays switched off after the update
torch.autograd.set_grad_enabled(False)
if self.writer is not None:
# plot loss
self.writer.add_scalar('train/loss', current_loss, self.train_call_count)
# plot min/max gradients
# sentinel start values, replaced by the first parameter gradient norm seen
max_grad_norm = -1.0
min_grad_norm = 1000000.0
for param in self.model.parameters():
if param.grad is not None:
# TODO decide on norm
current_grad_norm = torch.norm(param.grad, 2)
if current_grad_norm > max_grad_norm:
max_grad_norm = current_grad_norm
if current_grad_norm < min_grad_norm:
min_grad_norm = current_grad_norm
self.writer.add_scalar('train/min_grad', min_grad_norm, self.train_call_count)
self.writer.add_scalar('train/max_grad', max_grad_norm, self.train_call_count)
# update target net: copy the online weights every target_update_rate training calls
if self.target_model is not None and \
self.train_call_count % self.target_update_rate == 0:
self.target_model.load_state_dict(self.model.state_dict())
experience_buffer
¶
Buffer
¶
Base class for experience replay buffers
Initializes the memory, provides a print function for the memory contents and a method to insert new items into the buffer. Sampling has to be implemented by child classes.
Source code in adviser/services/policy/rl/experience_buffer.py
class Buffer(object):
    """ Base class for experience replay buffers

    Allocates the replay memory, offers a debugging print-out of its contents
    and a method for inserting new experiences.
    Sampling has to be implemented by child classes.
    """

    def __init__(self, buffer_size: int, batch_size: int, state_dim: int,
                 discount_gamma: float = 0.99, device=torch.device('cpu')):
        """ Allocate the (uninitialised) memory tensors on ``device``. """
        assert buffer_size >= batch_size, 'the buffer hast to be larger than the batch size'

        def _allocate(width, dtype):
            # one row per stored transition
            return torch.empty(buffer_size, width, dtype=dtype, device=device)

        self.device = device
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.discount_gamma = discount_gamma

        # construct memory
        self.mem_state = _allocate(state_dim, torch.float)
        self.mem_action = _allocate(1, torch.long)
        self.mem_reward = _allocate(1, torch.float)
        self.mem_next_state = _allocate(state_dim, torch.float)
        self.mem_terminal = _allocate(1, torch.float)

        self.write_pos = 0
        self.last_write_pos = 0
        self.buffer_count = 0

        self._reset()

    def _reset(self):
        """ Forget the pending (not yet recorded) turn between dialogs.

        Called from the constructor and from ``store`` when it is called
        with terminal=True.
        """
        self.last_state = self.last_action = self.last_reward = None
        self.episode_length = 0

    def store(self, state: torch.FloatTensor, action: torch.LongTensor, reward: float,
              terminal: bool = False):
        """ Store an experience of the form (s,a,r,s',t).

        Only the current state s has to be passed; the transition to s' is
        completed by the following call. The reward is divided by 20.0
        before it is kept.

        Args:
            state (torch.tensor): this turn's state tensor, or None if terminal = True
            action (torch.tensor): this turn's action index (int), or None if terminal = True
            reward (float): this turn's reward
            terminal (bool): indicates whether episode finished (boolean)

        Returns:
            True if a transition was written to the memory, False otherwise
        """
        reward /= 20.0

        if self.last_state is None:
            # first turn of trajectory, don't record since s' is needed
            self.last_state, self.last_action, self.last_reward = state, action, reward
            return False

        if terminal == True:  # noqa: E712 - equality test kept on purpose
            if self.episode_length > 0:
                # fold the final reward into the last recorded transition
                # and mark that transition as terminal
                previous = self.last_write_pos
                self.mem_terminal[previous] = float(True)
                self.mem_reward[previous] += reward
            self._reset()
            return False

        # in-between turn of trajectory: record
        slot = self.write_pos
        self.mem_state[slot] = self.last_state.clone().detach()
        self.mem_action[slot][0] = self.last_action
        self.mem_reward[slot][0] = self.last_reward
        self.mem_next_state[slot] = state.clone().detach()
        self.mem_terminal[slot] = float(False)

        # update last encountered state
        self.last_state = state.clone().detach()
        self.last_action = action
        self.last_reward = reward

        # update write index (wraps around at the end of the memory)
        self.last_write_pos = slot
        self.write_pos = (slot + 1) % self.buffer_size
        self.buffer_count = min(self.buffer_count + 1, self.buffer_size)

        self.episode_length += 1
        return True

    def print_contents(self, max_size: int = None):
        """ Print contents of the experience replay memory.

        Args:
            max_size (int): restrict the number of printed items to this number (if not None)
        """
        # how many entries to print
        stored = len(self)
        print_items = stored if max_size is None else min(stored, max_size)

        print("# REPLAY BUFFER CAPACITY: ", self.buffer_size)
        print("# CURRENT ITEM COUNT", len(self))
        for idx in range(print_items):
            print("entry ", idx)
            print(" action", self.mem_action[idx])
            print(" reward", self.mem_reward[idx])
            print(" terminal", self.mem_terminal[idx])
            print('---------')
        # TODO finish printing buffer (state, reward, actions, belief?)

    def __len__(self):
        """ Returns the number of items currently inside the buffer """
        return self.buffer_count

    def sample(self):
        """ Sample from buffer, has to be implemented by subclasses """
        raise NotImplementedError
__init__(self, buffer_size, batch_size, state_dim, discount_gamma=0.99, device=device(type='cpu'))
special
¶
Source code in adviser/services/policy/rl/experience_buffer.py
def __init__(self, buffer_size: int, batch_size: int, state_dim: int,
             discount_gamma: float = 0.99, device=torch.device('cpu')):
    """ Allocate the replay memory tensors and initialise the write bookkeeping.

    Args:
        buffer_size (int): number of transitions the buffer can hold; must be >= batch_size
        batch_size (int): batch size, stored on the instance
        state_dim (int): width of a single state vector
        discount_gamma (float): discount factor, stored on the instance
        device: torch device on which all memory tensors are allocated
    """
    assert buffer_size >= batch_size, 'the buffer hast to be larger than the batch size'

    def _allocate(width: int, dtype):
        # one row per transition; contents are uninitialised until written
        return torch.empty(buffer_size, width, dtype=dtype, device=device)

    self.device, self.buffer_size = device, buffer_size
    self.batch_size, self.discount_gamma = batch_size, discount_gamma
    # construct memory
    self.mem_state = _allocate(state_dim, torch.float)
    self.mem_action = _allocate(1, torch.long)
    self.mem_reward = _allocate(1, torch.float)
    self.mem_next_state = _allocate(state_dim, torch.float)
    self.mem_terminal = _allocate(1, torch.float)
    # ring-buffer bookkeeping
    self.write_pos = self.last_write_pos = self.buffer_count = 0
    self._reset()
__len__(self)
special
¶
print_contents(self, max_size=None)
¶
Print contents of the experience replay memory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
max_size |
int |
restrict the number of printed items to this number (if not None) |
None |
Source code in adviser/services/policy/rl/experience_buffer.py
def print_contents(self, max_size: int = None):
    """ Dump a summary of the stored transitions to stdout.

    Prints the buffer capacity and the current item count, followed by the
    action, reward and terminal flag of each printed entry (starting at index 0).

    Args:
        max_size (int): upper bound on the number of printed entries;
                        None prints every stored entry
    """
    stored = len(self)
    # number of entries that will actually be printed
    limit = stored if max_size is None else min(stored, max_size)
    print("# REPLAY BUFFER CAPACITY: ", self.buffer_size)
    print("# CURRENT ITEM COUNT", len(self))
    for entry_idx in range(limit):
        print("entry ", entry_idx)
        print(" action", self.mem_action[entry_idx])
        print(" reward", self.mem_reward[entry_idx])
        print(" terminal", self.mem_terminal[entry_idx])
        print('---------')
    # TODO finish printing buffer (state, reward, actions, belief?)
sample(self)
¶
store(self, state, action, reward, terminal=False)
¶
Store an experience of the form (s,a,r,s',t).
Only needs the current state s (will construct transition to s' automatically).
Parameters:
Name | Type | Description | Default |
---|---|---|---|
state |
torch.tensor |
this turn's state tensor, or None if terminal = True |
required |
action |
torch.tensor |
this turn's action index (int), or None if terminal = True |
required |
reward |
torch.tensor |
this turn's reward (float) |
required |
terminal |
bool |
indicates whether episode finished (boolean) |
False |
Source code in adviser/services/policy/rl/experience_buffer.py
def store(self, state: torch.FloatTensor, action: torch.LongTensor, reward: float,
          terminal: bool = False):
    """ Store an experience of the form (s,a,r,s',t).

    Only needs the current state s (will construct transition to s'
    automatically): the transition for the previous turn is written once the
    state of the following turn is known.

    Args:
        state (torch.tensor): this turn's state tensor, or None if terminal = True
        action (torch.tensor): this turn's action index (int), or None if terminal = True
        reward (float): this turn's reward; stored divided by 20.0
        terminal (bool): indicates whether episode finished (boolean)

    Returns:
        True if a new transition was written to the buffer, False otherwise
    """
    # scale out-of-place: `reward /= 20.0` would modify a tensor reward in the
    # caller's scope
    reward = reward / 20.0

    if self.last_state is None:
        # first turn of trajectory, don't record since s' is needed
        self.last_state = state
        self.last_action = action
        self.last_reward = reward
        return False

    if terminal:
        if self.episode_length > 0:
            # update last state's reward and set it to terminal
            self.mem_terminal[self.last_write_pos] = float(True)
            self.mem_reward[self.last_write_pos] += reward
        self._reset()
        return False

    # in-between turn of trajectory: record
    self.mem_state[self.write_pos] = self.last_state.clone().detach()
    self.mem_action[self.write_pos][0] = self.last_action
    self.mem_reward[self.write_pos][0] = self.last_reward
    self.mem_next_state[self.write_pos] = state.clone().detach()
    self.mem_terminal[self.write_pos] = float(False)
    # update last encountered state
    self.last_state = state.clone().detach()
    self.last_action = action
    self.last_reward = reward
    # update write index (ring buffer: wraps around at buffer_size)
    self.last_write_pos = self.write_pos
    self.write_pos = (self.write_pos + 1) % self.buffer_size
    if self.buffer_count < self.buffer_size:
        self.buffer_count += 1
    self.episode_length += 1
    return True
NaivePrioritizedBuffer (Buffer)
¶
Prioritized experience replay buffer.
Assigns sampling probabilities dependent on TD-error of the transitions.
Source code in adviser/services/policy/rl/experience_buffer.py
class NaivePrioritizedBuffer(Buffer):
    """ Prioritized experience replay buffer.

    Assigns sampling probabilities dependent on TD-error of the transitions.
    The per-transition weights are kept in a plain python list (`self.probs`).
    """

    def __init__(self, buffer_size: int, batch_size: int, state_dim: int,
                 sample_last_transition: bool = False,
                 regularisation: float = 0.00001, exponent: float = 0.6, beta: float = 0.4,
                 discount_gamma: float = 0.99, device=torch.device('cpu')):
        """
        Args:
            buffer_size (int): capacity of the buffer (passed to the base class)
            batch_size (int): size of a sampled batch (passed to the base class)
            state_dim (int): width of a state vector (passed to the base class)
            sample_last_transition (bool): stored on the instance; note that `sample`
                                           currently overrides it with True
            regularisation (float): constant added to a priority before exponentiation
            exponent (float): exponent applied to the regularised priority
            beta (float): exponent used for the importance sampling weights
            discount_gamma (float): discount factor (passed to the base class)
            device: torch device (passed to the base class)
        """
        super(NaivePrioritizedBuffer, self).__init__(buffer_size, batch_size, state_dim,
                                                     discount_gamma=discount_gamma,
                                                     device=device)
        print(" REPLAY MEMORY: NAIVE Prioritized")
        # unnormalised sampling weight per buffer slot
        self.probs = [0.0] * buffer_size
        self.regularisation = regularisation
        self.exponent = exponent
        self.beta = beta
        # largest sampling weight seen so far (already exponentiated, see update())
        self.max_p = 1.0
        self.sample_last_transition = sample_last_transition
        # TODO anneal beta over time (see paper prioritized experience replay)
        # note: did not make a significant difference with the tested parameters
        # - is it worth to re-implement that feature?

    def _priority_to_probability(self, priority: float):
        """ Convert a priority into an (unnormalised) sampling weight:
        (priority + regularisation) ** exponent.
        The weights are only normalised to probabilities inside `sample`.
        """
        return (priority + self.regularisation) ** self.exponent

    def store(self, state: torch.FloatTensor, action: torch.LongTensor, reward: float,
              terminal: bool = False):
        """ Store an experience of the form (s,a,r,s',t).

        Only needs the current state s (will construct transition to s'
        automatically).
        Newly added experience tuples will be assigned maximum priority.

        Args:
            state: this turn's state tensor, or None if terminal = True
            action: this turn's action index (int), or None if terminal = True
            reward: this turn's reward (float)
            terminal: indicates whether episode finished (boolean)
        """
        if super(NaivePrioritizedBuffer, self).store(state, action, reward, terminal=terminal):
            # only assign a weight if something new was added to the buffers.
            # max_p is already an exponentiated weight (see update()), so it is
            # used directly instead of being exponentiated a second time
            self.probs[self.last_write_pos] = self.max_p

    def update(self, idx: int, error: float):
        """ Update the priority of transition with index idx

        Args:
            idx: buffer index of the transition
            error: new priority of the transition (converted with
                   `_priority_to_probability` before it is stored)
        """
        p = self._priority_to_probability(error)
        if p > self.max_p:
            self.max_p = p
        self.probs[idx] = p

    def sample(self):
        """ Sample from buffer.

        Returns:
            states, actions, rewards, next states, terminal state indicator {0,1}, buffer indices,
            importance weights
        """
        # NOTE(review): this overrides the constructor argument, so the most recent
        # transition is always part of the batch - confirm whether this is intended
        self.sample_last_transition = True

        data_indices = torch.empty(self.batch_size, dtype=torch.long, device=self.device)
        probabilities = torch.empty(self.batch_size, dtype=torch.float, device=self.device)

        # normalise the weights of all filled slots to a probability distribution
        p_normed = np.array(self.probs[:self.buffer_count]) / np.linalg.norm(
            self.probs[:self.buffer_count], ord=1)
        indices = common.numpy.random.choice(list(range(self.buffer_count)),
                                             size=self.batch_size, p=p_normed)

        batch_write_pos = 0
        if self.sample_last_transition:
            # include last transition
            # -> see Sutton: A deeper look at experience replay
            data_indices[0] = self.last_write_pos
            probabilities[0] = self.probs[self.last_write_pos]
            batch_write_pos += 1

        # TODO add option to sample each segment uniformly
        for i in range(batch_write_pos, self.batch_size):
            data_indices[i] = int(indices[i])
            probabilities[i] = self.probs[data_indices[i]]

        # assemble batch from data indices
        s_batch = self.mem_state.index_select(0, data_indices)
        a_batch = self.mem_action.index_select(0, data_indices)
        r_batch = self.mem_reward.index_select(0, data_indices)
        t_batch = self.mem_terminal.index_select(0, data_indices)
        s2_batch = self.mem_next_state.index_select(0, data_indices)

        # calculate importance sampling weights (scaled so that the largest weight is 1)
        importance_weights = float(len(self)) * probabilities
        importance_weights = importance_weights.pow(-self.beta)
        importance_weights = importance_weights / importance_weights.max(dim=0)[0].item()
        return s_batch, a_batch, r_batch, s2_batch, t_batch, data_indices, \
            importance_weights.view(-1, 1)
__init__(self, buffer_size, batch_size, state_dim, sample_last_transition=False, regularisation=1e-05, exponent=0.6, beta=0.4, discount_gamma=0.99, device=device(type='cpu'))
special
¶
Source code in adviser/services/policy/rl/experience_buffer.py
def __init__(self, buffer_size: int, batch_size: int, state_dim: int,
             sample_last_transition: bool = False,
             regularisation: float = 0.00001, exponent: float = 0.6, beta: float = 0.4,
             discount_gamma: float = 0.99, device=torch.device('cpu')):
    """ Set up the base replay memory plus the per-slot sampling weights.

    Args:
        buffer_size (int): capacity of the buffer (passed to the base class)
        batch_size (int): size of a sampled batch (passed to the base class)
        state_dim (int): width of a state vector (passed to the base class)
        sample_last_transition (bool): stored on the instance
        regularisation (float): stored on the instance
        exponent (float): stored on the instance
        beta (float): stored on the instance
        discount_gamma (float): discount factor (passed to the base class)
        device: torch device (passed to the base class)
    """
    base_kwargs = dict(discount_gamma=discount_gamma, device=device)
    super(NaivePrioritizedBuffer, self).__init__(buffer_size, batch_size, state_dim,
                                                 **base_kwargs)
    print(" REPLAY MEMORY: NAIVE Prioritized")
    # hyper-parameters of the prioritisation
    self.regularisation, self.exponent, self.beta = regularisation, exponent, beta
    self.sample_last_transition = sample_last_transition
    # one weight per buffer slot, all starting at zero
    self.probs = [0.0 for _ in range(buffer_size)]
    self.max_p = 1.0
    # TODO anneal beta over time (see paper prioritized experience replay)
    # note: did not make a significant difference with the tested parameters
    # - is it worth to re-implement that feature?
sample(self)
¶
Sample from buffer.
Returns:
Type | Description |
---|---|
states, actions, rewards, next states, terminal state indicator {0,1}, buffer indices, importance weights |
Source code in adviser/services/policy/rl/experience_buffer.py
def sample(self):
    """ Sample from buffer.

    Buffer indices are drawn according to the stored weights (`self.probs`),
    normalised over all filled slots.

    Returns:
        states, actions, rewards, next states, terminal state indicator {0,1}, buffer indices,
        importance weights
    """
    # NOTE(review): this overrides the constructor argument, so the most recent
    # transition is always part of the batch - confirm whether this is intended
    self.sample_last_transition = True

    data_indices = torch.empty(self.batch_size, dtype=torch.long, device=self.device)
    probabilities = torch.empty(self.batch_size, dtype=torch.float, device=self.device)

    # normalise the weights of all filled slots to a probability distribution
    p_normed = np.array(self.probs[:self.buffer_count]) / np.linalg.norm(
        self.probs[:self.buffer_count], ord=1)
    indices = common.numpy.random.choice(list(range(self.buffer_count)),
                                         size=self.batch_size, p=p_normed)

    batch_write_pos = 0
    if self.sample_last_transition:
        # include last transition
        # -> see Sutton: A deeper look at experience replay
        data_indices[0] = self.last_write_pos
        probabilities[0] = self.probs[self.last_write_pos]
        batch_write_pos += 1

    # TODO add option to sample each segment uniformly
    for i in range(batch_write_pos, self.batch_size):
        data_indices[i] = int(indices[i])
        probabilities[i] = self.probs[data_indices[i]]

    # assemble batch from data indices
    s_batch = self.mem_state.index_select(0, data_indices)
    a_batch = self.mem_action.index_select(0, data_indices)
    r_batch = self.mem_reward.index_select(0, data_indices)
    t_batch = self.mem_terminal.index_select(0, data_indices)
    s2_batch = self.mem_next_state.index_select(0, data_indices)

    # calculate importance sampling weights (scaled so that the largest weight is 1)
    importance_weights = float(len(self)) * probabilities
    importance_weights = importance_weights.pow(-self.beta)
    importance_weights = importance_weights / importance_weights.max(dim=0)[0].item()
    return s_batch, a_batch, r_batch, s2_batch, t_batch, data_indices, \
        importance_weights.view(-1, 1)
store(self, state, action, reward, terminal=False)
¶
Store an experience of the form (s,a,r,s',t).
Only needs the current state s (will construct transition to s' automatically).
Newly added experience tuples will be assigned maximum priority.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
state |
FloatTensor |
this turn's state tensor, or None if terminal = True |
required |
action |
LongTensor |
this turn's action index (int), or None if terminal = True |
required |
reward |
float |
this turn's reward (float) |
required |
terminal |
bool |
indicates whether episode finished (boolean) |
False |
Source code in adviser/services/policy/rl/experience_buffer.py
def store(self, state: torch.FloatTensor, action: torch.LongTensor, reward: float,
          terminal: bool = False):
    """ Store an experience of the form (s,a,r,s',t).

    Only needs the current state s (will construct transition to s'
    automatically).
    Newly added experience tuples will be assigned maximum priority.

    Args:
        state: this turn's state tensor, or None if terminal = True
        action: this turn's action index (int), or None if terminal = True
        reward: this turn's reward (float)
        terminal: indicates whether episode finished (boolean)
    """
    if super(NaivePrioritizedBuffer, self).store(state, action, reward, terminal=terminal):
        # only assign a weight if something new was added to the buffers.
        # max_p is already an exponentiated weight (update() stores the result of
        # _priority_to_probability in it), so it is used directly instead of being
        # exponentiated a second time
        self.probs[self.last_write_pos] = self.max_p
update(self, idx, error)
¶
Update the priority of transition with index idx
UniformBuffer (Buffer)
¶
Experience replay buffer with uniformly random sampling
Source code in adviser/services/policy/rl/experience_buffer.py
class UniformBuffer(Buffer):
    """ Replay buffer that draws its batches uniformly at random """

    def __init__(self, buffer_size: int, batch_size: int, state_dim: int,
                 discount_gamma: float = 0.99, sample_last_transition: bool = True,
                 device=torch.device('cpu')):
        """
        Args:
            buffer_size (int): capacity of the buffer (passed to the base class)
            batch_size (int): size of a sampled batch (passed to the base class)
            state_dim (int): width of a state vector (passed to the base class)
            discount_gamma (float): discount factor (passed to the base class)
            sample_last_transition (bool): if True, a batch will always include the most
                                           recent transition
                                           (see Sutton: A deeper look at experience replay)
            device: torch device (passed to the base class)
        """
        base_kwargs = dict(discount_gamma=discount_gamma, device=device)
        super(UniformBuffer, self).__init__(buffer_size, batch_size, state_dim, **base_kwargs)
        print(" REPLAY MEMORY: Uniform")
        self.sample_last_transition = sample_last_transition

    def sample(self):
        """ Draw one batch of transitions.

        Returns:
            states, actions, rewards, next states, terminal state indicator {0,1}, buffer indices,
            None
        """
        picked = []
        if self.sample_last_transition:
            # most recent transition sits one slot before the write position
            # - see Sutton: A deeper look at experience replay
            newest = self.write_pos - 1
            if newest < 0:
                # write position wrapped around: newest entry is in the last slot
                newest = self.buffer_size - 1
            picked.append(newest)

        # fill the remaining slots of the batch with uniformly drawn indices
        remaining = self.batch_size - int(self.sample_last_transition)
        for _ in range(remaining):
            picked.append(common.random.randint(0, self.buffer_count - 1))

        data_indices = torch.tensor(picked, dtype=torch.long, device=self.device)

        def _gather(memory):
            return memory.index_select(0, data_indices)

        state_batch = _gather(self.mem_state)
        action_batch = _gather(self.mem_action)
        reward_batch = _gather(self.mem_reward)
        next_state_batch = _gather(self.mem_next_state)
        terminal_batch = _gather(self.mem_terminal)
        return (state_batch, action_batch, reward_batch, next_state_batch, terminal_batch,
                data_indices, None)
__init__(self, buffer_size, batch_size, state_dim, discount_gamma=0.99, sample_last_transition=True, device=device(type='cpu'))
special
¶
Parameters:
Name | Type | Description | Default |
---|---|---|---|
sample_last_transition |
bool |
if True, a batch will always include the most recent transition (see Sutton: A deeper look at experience replay) |
True |
Source code in adviser/services/policy/rl/experience_buffer.py
def __init__(self, buffer_size: int, batch_size: int, state_dim: int,
             discount_gamma: float = 0.99, sample_last_transition: bool = True,
             device=torch.device('cpu')):
    """ Set up the base replay memory and remember the sampling option.

    Args:
        buffer_size (int): capacity of the buffer (passed to the base class)
        batch_size (int): size of a sampled batch (passed to the base class)
        state_dim (int): width of a state vector (passed to the base class)
        discount_gamma (float): discount factor (passed to the base class)
        sample_last_transition (bool): if True, a batch will always include the most recent
                                       transition
                                       (see Sutton: A deeper look at experience replay)
        device: torch device (passed to the base class)
    """
    base_kwargs = dict(discount_gamma=discount_gamma, device=device)
    super(UniformBuffer, self).__init__(buffer_size, batch_size, state_dim, **base_kwargs)
    print(" REPLAY MEMORY: Uniform")
    self.sample_last_transition = sample_last_transition
sample(self)
¶
Sample from buffer.
Returns:
Type | Description |
---|---|
states, actions, rewards, next states, terminal state indicator {0,1}, buffer indices, None |
Source code in adviser/services/policy/rl/experience_buffer.py
def sample(self):
    """ Draw one batch of transitions uniformly at random.

    Returns:
        states, actions, rewards, next states, terminal state indicator {0,1}, buffer indices,
        None
    """
    picked = []
    if self.sample_last_transition:
        # most recent transition sits one slot before the write position
        # - see Sutton: A deeper look at experience replay
        newest = self.write_pos - 1
        if newest < 0:
            # write position wrapped around: newest entry is in the last slot
            newest = self.buffer_size - 1
        picked.append(newest)

    # fill the remaining slots of the batch with uniformly drawn indices
    remaining = self.batch_size - int(self.sample_last_transition)
    for _ in range(remaining):
        picked.append(common.random.randint(0, self.buffer_count - 1))

    data_indices = torch.tensor(picked, dtype=torch.long, device=self.device)

    def _gather(memory):
        return memory.index_select(0, data_indices)

    state_batch = _gather(self.mem_state)
    action_batch = _gather(self.mem_action)
    reward_batch = _gather(self.mem_reward)
    next_state_batch = _gather(self.mem_next_state)
    terminal_batch = _gather(self.mem_terminal)
    return (state_batch, action_batch, reward_batch, next_state_batch, terminal_batch,
            data_indices, None)
policy_rl
¶
RLPolicy
¶
Base class for Reinforcement Learning based policies.
Functionality provided includes the setup of state- and action spaces,
conversion of `BeliefState` objects into pytorch tensors,
updating the last performed system actions and informed entities,
populating the experience replay buffer,
extraction of most probable user hypothesis and candidate action expansion.
Output of an agent is a candidate action like `inform_food`,
which is then populated with the most probable slot/value pair from the
beliefstate and database candidates by the `expand_system_action` function to become
`inform(slot=food,value=italian)`.
In order to create your own policy, you can inherit from this class.
Make sure to call the `turn_end` function after each system turn and the `end_dialog` function
after each completed dialog.
Source code in adviser/services/policy/rl/policy_rl.py
class RLPolicy(object):
""" Base class for Reinforcement Learning based policies.
Functionality provided includes the setup of state- and action spaces,
conversion of `BeliefState` objects into pytorch tensors,
updating the last performed system actions and informed entities,
populating the experience replay buffer,
extraction of most probable user hypothesis and candidate action expansion.
Output of an agent is a candidate action like
``inform_food``
which is then populated with the most probable slot/value pair from the
beliefstate and database candidates by the `expand_system_action`-function to become
``inform(slot=food,value=italian)``.
In order to create your own policy, you can inherit from this class.
Make sure to call the `turn_end`-function after each system turn and the `end_dialog`-function
after each completed dialog.
"""
def __init__(self, domain: JSONLookupDomain, buffer_cls=UniformBuffer,
             buffer_size=6000, batch_size=64, discount_gamma=0.99, max_turns: int = 25,
             include_confreq=False, logger: DiasysLogger = DiasysLogger(),
             include_select: bool = False, device=torch.device('cpu')):
    """
    Creates state- and action spaces, initializes experience replay
    buffers.

    Arguments:
        domain {domain.jsonlookupdomain.JSONLookupDomain} -- Domain

    Keyword Arguments:
        buffer_cls {services.policy.rl.experience_buffer.Buffer}
            -- [Experience replay buffer *class*, **not** an instance - will be
                initialized by this constructor!] (default: {UniformBuffer})
        buffer_size {int} -- [see services.policy.rl.experience_buffer.
                              Buffer] (default: {6000})
        batch_size {int} -- [see services.policy.rl.experience_buffer.
                              Buffer] (default: {64})
        discount_gamma {float} -- [Discount factor] (default: {0.99})
        max_turns {int} -- [stored as self.max_turns] (default: {25})
        include_confreq {bool} -- [Use confirm_request actions]
                                  (default: {False})
        logger {DiasysLogger} -- [logger used by the policy and its evaluator]
        include_select {bool} -- [Use select actions] (default: {False})
        device {torch.device} -- [device for state tensors and the replay
                                  buffer] (default: cpu)
    """
    self.device = device
    # system-side features; they have to exist before the state dimension is
    # probed below, because beliefstate_dict_to_vector reads them
    self.sys_state = {
        "lastInformedPrimKeyVal": None,
        "lastActionInformNone": False,
        "offerHappened": False,
        'informedPrimKeyValsSinceNone': []}
    self.max_turns = max_turns
    self.logger = logger
    self.domain = domain
    # setup evaluator for training
    self.evaluator = ObjectiveReachedEvaluator(domain, logger=logger)
    self.buffer_size = buffer_size
    self.batch_size = batch_size
    self.discount_gamma = discount_gamma
    self.writer = None

    # get state size by vectorising a freshly initialised belief state
    self.state_dim = self.beliefstate_dict_to_vector(
        BeliefState(domain)._init_beliefstate()).size(1)
    self.logger.info("state space dim: " + str(self.state_dim))

    # get system action list
    self.actions = ["inform_byname",  # TODO rename to 'bykey'
                    "inform_alternatives",
                    "reqmore"]
    # TODO badaction
    for req_slot in self.domain.get_system_requestable_slots():
        self.actions.append('request#' + req_slot)
        self.actions.append('confirm#' + req_slot)
        if include_select:
            self.actions.append('select#' + req_slot)
        if include_confreq:
            for conf_slot in self.domain.get_system_requestable_slots():
                if not req_slot == conf_slot:
                    # skip case where confirm slot = request slot
                    self.actions.append('confreq#' + conf_slot + '#' +
                                        req_slot)
    self.action_dim = len(self.actions)
    # don't include closingmsg in learnable actions:
    # it is appended only after action_dim has been fixed
    self.actions.append('closingmsg')
    self.logger.info("action space dim: " + str(self.action_dim))

    self.primary_key = self.domain.get_primary_key()

    # init replay memory
    self.buffer = buffer_cls(buffer_size, batch_size, self.state_dim,
                             discount_gamma=discount_gamma, device=device)

    # sys_state deliberately keeps the defaults assigned above: resetting it to an
    # empty dict here would make beliefstate_dict_to_vector and
    # _update_system_belief fail with a KeyError until it is re-initialised
    self.last_sys_act = None
def action_name(self, action_idx: int):
    """ Look up the name of the action stored at position `action_idx` of the action list """
    name = self.actions[action_idx]
    return name
def action_idx(self, action_name: str):
    """ Look up the position of `action_name` in the action list
    (raises ValueError for unknown names) """
    position = self.actions.index(action_name)
    return position
def beliefstate_dict_to_vector(self, beliefstate: BeliefState):
    """ Converts the beliefstate dict to a torch tensor

    The vector is assembled in this order: one flag per user act type, a
    "no user act" flag, per informable slot a "slot not mentioned" flag followed
    by the beliefs for all possible values (plus dontcare), one flag per
    requestable slot, the system features and the match count buckets.

    Args:
        beliefstate: dict of belief (with at least beliefs and system keys)

    Returns:
        belief tensor with dimension 1 x state_dim
    """
    belief_vec = []

    # add user acts
    belief_vec += [1 if act in beliefstate['user_acts'] else 0 for act in UserActionType]
    # handle none actions: flag is set only if no user act was detected.
    # (it was previously `1 if ... else 1`, i.e. a constant feature - models trained
    # on that constant input may need retraining)
    belief_vec.append(1 if sum(belief_vec) == 0 else 0)

    # add informs (including special flag if slot not mentioned)
    for slot in sorted(self.domain.get_informable_slots()):
        values = self.domain.get_possible_values(slot) + ["dontcare"]
        if slot not in beliefstate['informs']:
            # add **NONE** value first, then 0.0 for all others
            # (including the value for don't care)
            belief_vec.append(1.0)
            belief_vec += [0.0] * len(values)
        else:
            # add **NONE** value first
            belief_vec.append(0.0)
            bs_slot = beliefstate['informs'][slot]
            belief_vec += [bs_slot[value] if value in bs_slot else 0.0 for value in values]

    # add requests
    for slot in sorted(self.domain.get_requestable_slots()):
        belief_vec.append(1.0 if slot in beliefstate['requests'] else 0.0)

    # append system features
    belief_vec.append(float(self.sys_state['lastActionInformNone']))
    belief_vec.append(float(self.sys_state['offerHappened']))
    candidate_count = beliefstate['num_matches']
    # buckets for match count: 0, 1, 2-4, >4
    belief_vec.append(float(candidate_count == 0))
    belief_vec.append(float(candidate_count == 1))
    belief_vec.append(float(2 <= candidate_count <= 4))
    belief_vec.append(float(candidate_count > 4))
    belief_vec.append(float(beliefstate["discriminable"]))

    # convert to torch tensor (wrapped in a list to get the leading batch dimension of 1)
    return torch.tensor([belief_vec], dtype=torch.float, device=self.device)
def _remove_dontcare_slots(self, slot_value_dict: dict):
    """ Build a new slot/value dictionary that omits every slot whose value is 'dontcare' """
    kept = {}
    for slot, value in slot_value_dict.items():
        if value != 'dontcare':
            kept[slot] = value
    return kept
def _get_slotnames_from_actionname(self, action_name: str):
    """ Return the slot names of an action name of the form 'action#slot1#slot2#...'
    (everything after the first '#'-separated field) """
    _action_type, *slot_names = action_name.split('#')
    return slot_names
def _db_results_to_sysact(self, sys_act: SysAct, constraints: dict, db_entity: dict):
    """ Copies the database values of all constrained slots into sys_act
    (the primary key is always added).

    For every slot in `constraints` that is present in `db_entity`, the value from
    `db_entity` is added (so a user 'dontcare' is replaced by the database value).
    Slots missing in `db_entity` are skipped; a warning is logged if a logger is set.
    """
    # NOTE(review): values equal to 'not available' are copied as well - confirm
    # whether they were meant to be omitted
    for constraint_slot in constraints:
        if constraint_slot in db_entity:
            # fill with database value
            sys_act.add_value(constraint_slot, db_entity[constraint_slot])
        elif self.logger:
            # slot not in db entity -> create warning
            self.logger.warning("Slot " + constraint_slot +
                                " not found in db entity " +
                                db_entity[self.primary_key])
    # ensure primary key is included
    if self.primary_key not in sys_act.slot_values:
        sys_act.add_value(self.primary_key, db_entity[self.primary_key])
def _expand_byconstraints(self, beliefstate: BeliefState):
    """ Build an inform act from the database entities matching the current constraints.

    If at least one entity matches, the act is filled from one matching entity
    (preferring the last informed one); otherwise the act contains a sample of
    the constraints together with primary key = 'none'.
    """
    act = SysAct()
    act.type = SysActionType.Inform
    # get constraints and query db by them
    constraints = beliefstate.get_most_probable_inf_beliefs(consider_NONE=True, threshold=0.7,
                                                            max_results=1)
    db_matches = self.domain.find_entities(constraints, requested_slots=constraints)

    if db_matches:
        # if the matches contain the last informed entity, stick to it
        last_informed = self.sys_state['lastInformedPrimKeyVal']
        same_entity = [entity for entity in db_matches
                       if entity[self.primary_key] == last_informed]
        if same_entity:
            assert len(same_entity) == 1
            chosen = same_entity[0]
        else:
            # none matches last informed venue -> pick a random result
            chosen = common.random.choice(db_matches)
        # fill act with values from db
        self._db_results_to_sysact(act, constraints, chosen)
        return act

    # no matching entity found -> return inform with primary key=none
    # and (up to 5 of the) other constraints
    filtered_slot_values = self._remove_dontcare_slots(constraints)
    sampled_slots = common.numpy.random.choice(
        list(filtered_slot_values.keys()),
        min(5, len(filtered_slot_values)), replace=False)
    for slot in sampled_slots:
        if slot != 'name':
            act.add_value(slot, filtered_slot_values[slot])
    act.add_value(self.primary_key, 'none')
    return act
def _expand_request(self, action_name: str):
    """ Expand a 'request#slot' action name into a Request system act for that slot """
    slot_names = self._get_slotnames_from_actionname(action_name)
    request_act = SysAct()
    request_act.type = SysActionType.Request
    request_act.add_value(slot_names[0])
    return request_act
def _expand_select(self, action_name: str, beliefstate: BeliefState):
    """ Expand a 'select#slot' action name into a Select system act
    offering the two most probable values of that slot """
    select_act = SysAct()
    select_act.type = SysActionType.Select
    sel_slot = self._get_slotnames_from_actionname(action_name)[0]
    ranked = beliefstate.get_most_probable_slot_beliefs(
        consider_NONE=False, slot=sel_slot, threshold=0.0, max_results=2)
    # both values are looked up before any of them is added to the act
    best_value = ranked[sel_slot][0]
    runner_up = ranked[sel_slot][1]
    for value in (best_value, runner_up):
        select_act.add_value(sel_slot, value)
    return select_act
def _expand_confirm(self, action_name: str, beliefstate: BeliefState):
    """ Expand a 'confirm#slot' action name into a Confirm system act.

    The value to confirm is the most probable belief for the slot; if the slot
    is not in the beliefstate, a random value from the ontology is used.
    """
    act = SysAct()
    act.type = SysActionType.Confirm
    conf_slot = self._get_slotnames_from_actionname(action_name)[0]
    candidates = beliefstate.get_most_probable_inf_beliefs(consider_NONE=False, threshold=0.0,
                                                           max_results=1)
    if conf_slot in candidates:
        conf_value = candidates[conf_slot]
    else:
        # If the slot isn't in the beliefstate choose a random value from the ontology.
        # common.random is used (instead of the bare random module) to stay
        # consistent with the other expansion methods of this class
        conf_value = common.random.choice(self.domain.get_possible_values(conf_slot))
    act.add_value(conf_slot, conf_value)
    return act
def _expand_confreq(self, action_name: str, beliefstate: BeliefState):
    """ Expand a 'confreq#confirmslot#requestslot' action name into a
    ConfirmRequest system act.

    The value to confirm is the most probable belief for the confirm slot; if
    that slot is not in the beliefstate, a random value from the ontology is
    used (same fallback as in `_expand_confirm`) instead of raising a KeyError.
    """
    act = SysAct()
    act.type = SysActionType.ConfirmRequest
    # first slot name is confirmation, second is request
    slots = self._get_slotnames_from_actionname(action_name)
    conf_slot = slots[0]
    req_slot = slots[1]

    # get value that needs confirmation
    candidates = beliefstate.get_most_probable_inf_beliefs(consider_NONE=False, threshold=0.0,
                                                           max_results=1)
    if conf_slot in candidates:
        conf_value = candidates[conf_slot]
    else:
        # slot has no belief yet -> fall back to a random value from the ontology
        conf_value = common.random.choice(self.domain.get_possible_values(conf_slot))
    act.add_value(conf_slot, conf_value)
    # add request slot
    act.add_value(req_slot)
    return act
def _expand_informbyname(self, beliefstate: BeliefState):
    """ Expand the inform_byname action into an InformByName system act.

    Looks up an entity by the most probable primary key value (or the last
    informed one) and informs about the slots the user requested; if the user
    requested nothing, informs about (some of) the constraints or one random slot.
    If no entity can be found at all, the act contains a sample of the
    constraints together with primary key = 'none'.
    """
    act = SysAct()
    act.type = SysActionType.InformByName

    # get most probable entity primary key
    if self.primary_key in beliefstate['informs']:
        ranked_keys = sorted(
            beliefstate["informs"][self.primary_key].items(),
            key=lambda kv: kv[1], reverse=True)
        primkeyval = ranked_keys[0][0]
    else:
        # try to use previously informed name instead
        primkeyval = self.sys_state['lastInformedPrimKeyVal']
        # TODO change behaviour from here, because primkeyval might be "**NONE**" and this might be an entity in the database

    # find db entry by primary key
    constraints = beliefstate.get_most_probable_inf_beliefs(consider_NONE=True, threshold=0.7,
                                                            max_results=1)
    query = {**constraints, self.primary_key: primkeyval}
    db_matches = self.domain.find_entities(query)
    # NOTE usually not needed to give all constraints (shouldn't make a difference)

    if not db_matches:
        # select random entity if none could be found
        primkeyvals = self.domain.get_possible_values(self.primary_key)
        primkeyval = common.random.choice(primkeyvals)
        # use knowledge from current belief state
        db_matches = self.domain.find_entities(
            constraints, self.domain.get_requestable_slots())

    if not db_matches:
        # no results found
        filtered_slot_values = self._remove_dontcare_slots(constraints)
        sampled_slots = common.numpy.random.choice(
            list(filtered_slot_values.keys()),
            min(5, len(filtered_slot_values)), replace=False)
        for slot in sampled_slots:
            act.add_value(slot, filtered_slot_values[slot])
        act.add_value(self.primary_key, 'none')
        return act

    # select random match and fetch all its requestable slots
    picked = common.random.choice(db_matches)
    db_match = self.domain.find_info_about_entity(
        picked[self.primary_key], requested_slots=self.domain.get_requestable_slots())[0]

    # get slots requested by user
    usr_requests = beliefstate.get_requested_slots()
    # remove primary key (to exclude from minimum number) since it is added anyway at the end
    if self.primary_key in usr_requests:
        usr_requests.remove(self.primary_key)

    if usr_requests:
        # add user requested values into system act using db result
        sampled_requests = common.numpy.random.choice(
            usr_requests, min(4, len(usr_requests)), replace=False)
        for req_slot in sampled_requests:
            if req_slot in db_match:
                act.add_value(req_slot, db_match[req_slot])
            else:
                act.add_value(req_slot, 'none')
    else:
        informable = self._remove_dontcare_slots(constraints)
        if informable:
            sampled_informs = common.numpy.random.choice(
                list(informable.keys()),
                min(4, len(informable)), replace=False)
            for inform_slot in sampled_informs:
                act.add_value(inform_slot, db_match[inform_slot])
        else:
            # add random slot and value if no user request was detected
            usr_requestable_slots = set(self.domain.get_requestable_slots())
            usr_requestable_slots.remove(self.primary_key)
            random_slot = common.random.choice(list(usr_requestable_slots))
            act.add_value(random_slot, db_match[random_slot])

    # ensure entity primary key is included
    if self.primary_key not in act.slot_values:
        act.add_value(self.primary_key, db_match[self.primary_key])
    return act
def _expand_informbyalternatives(self, beliefstate: BeliefState):
    """ Expand the inform_alternatives action into an InformByAlternatives system act.

    Informs about the first matching database entity that was not informed
    since the last 'none' result; if there is no match or no such entity, the
    act carries primary key = 'none'.
    """
    act = SysAct()
    act.type = SysActionType.InformByAlternatives
    # get set of all previously informed primary key values
    already_informed = set(self.sys_state['informedPrimKeyValsSinceNone'])
    candidates = beliefstate.get_most_probable_inf_beliefs(consider_NONE=True, threshold=0.7,
                                                           max_results=1)
    filtered_slot_values = self._remove_dontcare_slots(candidates)
    # query db by constraints
    db_matches = self.domain.find_entities(candidates)

    if not db_matches:
        # no results found
        sampled_slots = common.numpy.random.choice(
            list(filtered_slot_values.keys()),
            min(5, len(filtered_slot_values)), replace=False)
        for slot in sampled_slots:
            act.add_value(slot, filtered_slot_values[slot])
        act.add_value(self.primary_key, 'none')
        return act

    # don't inform about already informed entities
    # -> take the first match whose primary key value was not informed yet
    alternative = next(
        (entity for entity in db_matches
         if entity[self.primary_key] not in already_informed),
        None)
    if alternative is None:
        # no alternatives found (that were not already mentioned)
        act.add_value(self.primary_key, 'none')
        return act

    sampled_slots = common.numpy.random.choice(
        list(filtered_slot_values.keys()),
        min(4 - (len(act.slot_values)), len(filtered_slot_values)), replace=False)
    for slot in sampled_slots:
        act.add_value(slot, filtered_slot_values[slot])

    # dontcare constraints are only filled from the db while the act has < 4 slots
    additional_constraints = {
        slot: value for slot, value in candidates.items()
        if len(act.slot_values) < 4 and value == 'dontcare'}

    entity_info = self.domain.find_info_about_entity(
        alternative[self.primary_key],
        requested_slots=self.domain.get_informable_slots())[0]
    self._db_results_to_sysact(act, additional_constraints, entity_info)
    return act
def _expand_bye(self):
    """ Build the system act for the bye action (type ``SysActionType.Bye``). """
    bye_act = SysAct()
    bye_act.type = SysActionType.Bye
    return bye_act
def _expand_reqmore(self):
    """ Build the system act for the reqmore action (type ``SysActionType.RequestMore``). """
    reqmore_act = SysAct()
    reqmore_act.type = SysActionType.RequestMore
    return reqmore_act
def expand_system_action(self, action_idx: int, beliefstate: BeliefState):
    """ Expands an action index to a real system act.

    Args:
        action_idx: index that ``self.action_name`` maps to an action name
        beliefstate: current belief state, forwarded to the expanders that need it

    Returns:
        The expanded system act, ``self.last_sys_act`` for ``'repeat'``, or
        ``None`` (after logging a warning, if a logger is set) when the
        action name is not supported.
    """
    name = self.action_name(action_idx)

    # slot-parameterised actions are recognised by a marker substring;
    # the markers are tested in this fixed order
    if 'request#' in name:
        return self._expand_request(name)
    for marker, expander in (('select#', '_expand_select'),
                             ('confirm#', '_expand_confirm'),
                             ('confreq#', '_expand_confreq')):
        if marker in name:
            return getattr(self, expander)(name, beliefstate)

    # actions identified by their exact name
    if name == 'inform_byname':
        return self._expand_informbyname(beliefstate)
    if name == 'inform_alternatives':
        return self._expand_informbyalternatives(beliefstate)
    if name == 'closingmsg':
        return self._expand_bye()
    if name == 'repeat':
        return self.last_sys_act
    if name == 'reqmore':
        return self._expand_reqmore()

    if self.logger:
        self.logger.warning("RL POLICY: system action not supported: " +
                            name)
    return None
def _update_system_belief(self, beliefstate: BeliefState, sys_act: SysAct):
    """ Update the system's belief state features kept in ``self.sys_state``.

    Args:
        beliefstate: accepted for interface compatibility; not read here
        sys_act: the system act produced in this turn
    """
    state = self.sys_state

    # per-turn flags always start out cleared
    state['lastActionInformNone'] = False
    state['offerHappened'] = False

    # did the system inform a primary key value this turn? If so, remember
    # it; a value of 'none' (no matching entity) resets the offer history
    informed_names = sys_act.get_values(self.primary_key)
    if len(informed_names) > 0 and len(informed_names[0]) > 0:
        offered = informed_names[0]
        state['offerHappened'] = True
        state['lastInformedPrimKeyVal'] = offered
        if offered == 'none':
            state['informedPrimKeyValsSinceNone'] = []
            state['lastActionInformNone'] = True
        else:
            state['informedPrimKeyValsSinceNone'].append(offered)

    # set last system request slot s.t. BST can infer "this" reference
    # from user simulator (where user inform omits primary key slot)
    if sys_act.type in (SysActionType.Request, SysActionType.Select):
        requested = list(sys_act.slot_values.keys())[0]
    elif sys_act.type == SysActionType.ConfirmRequest:
        # the requested slot is the first one that carries no value
        requested = next(
            (slot for slot in sys_act.slot_values if len(sys_act.get_values(slot)) == 0),
            None)
    else:
        requested = None
    state['lastRequestSlot'] = requested
def turn_end(self, beliefstate: BeliefState, state_vector: torch.FloatTensor,
             sys_act_idx: int):
    """ Call this function after a turn is done by the system.

    Expands the chosen action index, updates the system state features and,
    while training, stores the transition in the replay buffer.

    Args:
        beliefstate: belief state the action was chosen for
        state_vector: vectorised belief state, stored in the buffer
        sys_act_idx: index of the chosen system action
    """
    sys_act = self.expand_system_action(sys_act_idx, beliefstate)
    self.last_sys_act = sys_act
    if self.logger:
        self.logger.dialog_turn("system action > " + str(self.last_sys_act))
    self._update_system_belief(beliefstate, self.last_sys_act)

    # the turn reward is always fetched from the evaluator, even when not training
    reward = self.evaluator.get_turn_reward()
    if not self.is_training:
        return
    self.buffer.store(state_vector, sys_act_idx, reward, terminal=False)
def _expand_hello(self):
    """ Call this function when a dialog begins.

    Returns:
        dict with the key ``'sys_act'`` holding a ``Welcome`` system act,
        which is also remembered as ``self.last_sys_act``.
    """
    welcome = SysAct()
    welcome.type = SysActionType.Welcome
    self.last_sys_act = welcome
    if self.logger:
        self.logger.dialog_turn("system action > " + str(welcome))
    return {'sys_act': welcome}
def end_dialog(self, sim_goal: Goal):
    """ Call this function when a dialog ended.

    Args:
        sim_goal: goal of the user simulator, or ``None`` for an interaction
            with a real user (nothing is evaluated in that case)
    """
    if sim_goal is None:
        # real user interaction, no simulator - nothing to evaluate
        return

    # the success flag returned by the evaluator is not used here
    final_reward, _success = self.evaluator.get_final_reward(sim_goal, logging=False)
    if not self.is_training:
        return
    # terminal transition: no state vector / action index
    self.buffer.store(None, None, final_reward, terminal=True)
__init__(self, domain, buffer_cls=<class 'services.policy.rl.experience_buffer.UniformBuffer'>, buffer_size=6000, batch_size=64, discount_gamma=0.99, max_turns=25, include_confreq=False, logger=<DiasysLogger adviser (NOTSET)>, include_select=False, device=device(type='cpu'))
special
¶
Creates state- and action spaces, initializes experience replay buffers.
Keyword arguments:
Name | Type | Description | Default |
---|---|---|---|
subgraph | [type] | see services.Module | None |
buffer_cls | services.policy.rl.experience_buffer.Buffer | Experience replay buffer *class*, **not** an instance - will be initialized by this constructor! | UniformBuffer |
buffer_size | int | see services.policy.rl.experience_buffer.Buffer | 6000 |
batch_size | int | see services.policy.rl.experience_buffer.Buffer | 64 |
discount_gamma | float | Discount factor | 0.99 |
include_confreq | bool | Use confirm_request actions | False |
Source code in adviser/services/policy/rl/policy_rl.py
def __init__(self, domain: JSONLookupDomain, buffer_cls=UniformBuffer,
             buffer_size=6000, batch_size=64, discount_gamma=0.99, max_turns: int = 25,
             include_confreq=False, logger: DiasysLogger = DiasysLogger(),
             include_select: bool = False, device=torch.device('cpu')):
    """
    Creates state- and action spaces, initializes experience replay
    buffers.

    Arguments:
        domain {domain.jsonlookupdomain.JSONLookupDomain} -- Domain

    Keyword Arguments:
        buffer_cls {services.policy.rl.experience_buffer.Buffer}
            -- [Experience replay buffer *class*, **not** an instance - will be
                initialized by this constructor!] (default: {UniformBuffer})
        buffer_size {int} -- [see services.policy.rl.experience_buffer.
                              Buffer] (default: {6000})
        batch_size {int} -- [see services.policy.rl.experience_buffer.
                             Buffer] (default: {64})
        discount_gamma {float} -- [Discount factor] (default: {0.99})
        max_turns {int} -- [stored as self.max_turns] (default: {25})
        include_confreq {bool} -- [Use confirm_request actions]
                                  (default: {False})
        logger {DiasysLogger} -- [logger, also handed to the evaluator]
        include_select {bool} -- [Use select actions] (default: {False})
        device {torch.device} -- [device for state tensors and the replay
                                  buffer] (default: {cpu})
    """
    self.device = device
    # system-side features; read by beliefstate_dict_to_vector below
    self.sys_state = {
        "lastInformedPrimKeyVal": None,
        "lastActionInformNone": False,
        "offerHappened": False,
        'informedPrimKeyValsSinceNone': []}

    self.max_turns = max_turns
    self.logger = logger
    self.domain = domain
    # setup evaluator for training
    self.evaluator = ObjectiveReachedEvaluator(domain, logger=logger)

    self.buffer_size = buffer_size
    self.batch_size = batch_size
    self.discount_gamma = discount_gamma

    self.writer = None

    # get state size by vectorising a freshly initialised belief state
    self.state_dim = self.beliefstate_dict_to_vector(
        BeliefState(domain)._init_beliefstate()).size(1)
    self.logger.info("state space dim: " + str(self.state_dim))

    # get system action list
    self.actions = ["inform_byname",  # TODO rename to 'bykey'
                    "inform_alternatives",
                    "reqmore"]
    # TODO badaction
    for req_slot in self.domain.get_system_requestable_slots():
        self.actions.append('request#' + req_slot)
        self.actions.append('confirm#' + req_slot)
        if include_select:
            self.actions.append('select#' + req_slot)
        if include_confreq:
            for conf_slot in self.domain.get_system_requestable_slots():
                if not req_slot == conf_slot:
                    # skip case where confirm slot = request slot
                    self.actions.append('confreq#' + conf_slot + '#' +
                                        req_slot)
    # action_dim is taken *before* 'closingmsg' is appended, so that
    # closingmsg is not part of the learnable actions
    self.action_dim = len(self.actions)
    self.actions.append('closingmsg')
    self.logger.info("action space dim: " + str(self.action_dim))

    self.primary_key = self.domain.get_primary_key()

    # init replay memory
    self.buffer = buffer_cls(buffer_size, batch_size, self.state_dim,
                             discount_gamma=discount_gamma, device=device)

    # NOTE: sys_state used to be reset to {} here, which threw away the
    # initial values set above and made beliefstate_dict_to_vector /
    # _update_system_belief raise KeyError until the state was re-initialised
    # elsewhere. The initialised dict is kept instead.
    self.last_sys_act = None
action_idx(self, action_name)
¶
action_name(self, action_idx)
¶
beliefstate_dict_to_vector(self, beliefstate)
¶
Converts the beliefstate dict to a torch tensor
Parameters:
Name | Type | Description | Default |
---|---|---|---|
beliefstate |
BeliefState |
dict of belief (with at least beliefs and system keys) |
required |
Returns:
Type | Description |
---|---|
belief tensor with dimension 1 x state_dim |
Source code in adviser/services/policy/rl/policy_rl.py
def beliefstate_dict_to_vector(self, beliefstate: BeliefState):
    """ Converts the beliefstate dict to a torch tensor

    Layout of the vector, in order: one indicator per ``UserActionType``,
    a "no user act" flag, per informable slot a "not mentioned" flag followed
    by one entry per possible value (plus ``dontcare``), one indicator per
    requestable slot, and finally the system features (last action informed
    none, offer happened, four match-count buckets, discriminable).

    Args:
        beliefstate: dict of belief; the keys ``user_acts``, ``informs``,
            ``requests``, ``num_matches`` and ``discriminable`` are read

    Returns:
        belief tensor with dimension 1 x state_dim
    """
    belief_vec = []

    # add user acts
    belief_vec += [1 if act in beliefstate['user_acts'] else 0 for act in UserActionType]
    # handle none actions: flag is 1 only if no user act is present.
    # (Previously both branches were 1, which made this feature constant.)
    belief_vec.append(1 if sum(belief_vec) == 0 else 0)

    # add informs (including special flag if slot not mentioned)
    for slot in sorted(self.domain.get_informable_slots()):
        values = self.domain.get_possible_values(slot) + ["dontcare"]
        if slot not in beliefstate['informs']:
            # add **NONE** value first, then 0 for all others (incl. dontcare)
            belief_vec.append(1.0)
            belief_vec += [0] * len(values)
        else:
            # add **NONE** value first
            belief_vec.append(0.0)
            bs_slot = beliefstate['informs'][slot]
            belief_vec += [bs_slot[value] if value in bs_slot else 0.0 for value in values]

    # add requests
    for slot in sorted(self.domain.get_requestable_slots()):
        belief_vec.append(1.0 if slot in beliefstate['requests'] else 0.0)

    # append system features
    belief_vec.append(float(self.sys_state['lastActionInformNone']))
    belief_vec.append(float(self.sys_state['offerHappened']))
    candidate_count = beliefstate['num_matches']
    # buckets for match count: 0, 1, 2-4, >4
    belief_vec.append(float(candidate_count == 0))
    belief_vec.append(float(candidate_count == 1))
    belief_vec.append(float(2 <= candidate_count <= 4))
    belief_vec.append(float(candidate_count > 4))
    belief_vec.append(float(beliefstate["discriminable"]))

    # convert to torch tensor
    return torch.tensor([belief_vec], dtype=torch.float, device=self.device)
end_dialog(self, sim_goal)
¶
Call this function when a dialog ended
Source code in adviser/services/policy/rl/policy_rl.py
def end_dialog(self, sim_goal: Goal):
    """ Call this function when a dialog ended.

    Args:
        sim_goal: goal of the user simulator, or ``None`` for an interaction
            with a real user (nothing is evaluated in that case)
    """
    if sim_goal is None:
        # real user interaction, no simulator - nothing to evaluate
        return

    # the success flag returned by the evaluator is not used here
    final_reward, _success = self.evaluator.get_final_reward(sim_goal, logging=False)
    if not self.is_training:
        return
    # terminal transition: no state vector / action index
    self.buffer.store(None, None, final_reward, terminal=True)
expand_system_action(self, action_idx, beliefstate)
¶
Expands an action index to a real system act
Source code in adviser/services/policy/rl/policy_rl.py
def expand_system_action(self, action_idx: int, beliefstate: BeliefState):
    """ Expands an action index to a real system act.

    Args:
        action_idx: index that ``self.action_name`` maps to an action name
        beliefstate: current belief state, forwarded to the expanders that need it

    Returns:
        The expanded system act, ``self.last_sys_act`` for ``'repeat'``, or
        ``None`` (after logging a warning, if a logger is set) when the
        action name is not supported.
    """
    name = self.action_name(action_idx)

    # slot-parameterised actions are recognised by a marker substring;
    # the markers are tested in this fixed order
    if 'request#' in name:
        return self._expand_request(name)
    for marker, expander in (('select#', '_expand_select'),
                             ('confirm#', '_expand_confirm'),
                             ('confreq#', '_expand_confreq')):
        if marker in name:
            return getattr(self, expander)(name, beliefstate)

    # actions identified by their exact name
    if name == 'inform_byname':
        return self._expand_informbyname(beliefstate)
    if name == 'inform_alternatives':
        return self._expand_informbyalternatives(beliefstate)
    if name == 'closingmsg':
        return self._expand_bye()
    if name == 'repeat':
        return self.last_sys_act
    if name == 'reqmore':
        return self._expand_reqmore()

    if self.logger:
        self.logger.warning("RL POLICY: system action not supported: " +
                            name)
    return None
turn_end(self, beliefstate, state_vector, sys_act_idx)
¶
Call this function after a turn is done by the system
Source code in adviser/services/policy/rl/policy_rl.py
def turn_end(self, beliefstate: BeliefState, state_vector: torch.FloatTensor,
             sys_act_idx: int):
    """ Call this function after a turn is done by the system.

    Expands the chosen action index, updates the system state features and,
    while training, stores the transition in the replay buffer.

    Args:
        beliefstate: belief state the action was chosen for
        state_vector: vectorised belief state, stored in the buffer
        sys_act_idx: index of the chosen system action
    """
    sys_act = self.expand_system_action(sys_act_idx, beliefstate)
    self.last_sys_act = sys_act
    if self.logger:
        self.logger.dialog_turn("system action > " + str(self.last_sys_act))
    self._update_system_belief(beliefstate, self.last_sys_act)

    # the turn reward is always fetched from the evaluator, even when not training
    reward = self.evaluator.get_turn_reward()
    if not self.is_training:
        return
    self.buffer.store(state_vector, sys_act_idx, reward, terminal=False)
train_dqnpolicy
¶
This script can be executed to train a DQN policy.¶
It will create a policy model (file ending with .pt).¶
You need to execute this script before you can interact with the RL agent.¶
¶
get_root_dir()
¶
train(domain_name, log_to_file, seed, train_epochs, train_dialogs, eval_dialogs, max_turns, train_error_rate, test_error_rate, lr, eps_start, grad_clipping, buffer_classname, buffer_size, use_tensorboard)
¶
Training loop for the RL policy, for information on the parameters, look at the descriptions of commandline arguments in the "if main" below
Source code in adviser/services/policy/rl/train_dqnpolicy.py
def train(domain_name: str, log_to_file: bool, seed: int, train_epochs: int, train_dialogs: int,
          eval_dialogs: int, max_turns: int, train_error_rate: float, test_error_rate: float,
          lr: float, eps_start: float, grad_clipping: float, buffer_classname: str,
          buffer_size: int, use_tensorboard: bool):
    """
    Training loop for the RL policy, for information on the parameters, look at the descriptions
    of commandline arguments in the "if main" below

    Each epoch runs ``train_dialogs`` training dialogs, saves the policy and
    then runs ``eval_dialogs`` evaluation dialogs.

    Note:
        ``max_turns``, ``train_error_rate`` and ``test_error_rate`` are accepted
        but currently not used by this function (the noise service is disabled).

    Raises:
        ValueError: if ``buffer_classname`` is neither ``"prioritized"`` nor
            ``"uniform"``.
    """
    # resolve the replay buffer class first, so that a bad name fails fast
    # instead of surfacing later as an unbound local variable
    if buffer_classname == "prioritized":
        buffer_cls = NaivePrioritizedBuffer
    elif buffer_classname == "uniform":
        buffer_cls = UniformBuffer
    else:
        raise ValueError(
            f"unknown buffer_classname '{buffer_classname}' "
            "(expected 'prioritized' or 'uniform')")

    # a seed of -1 means "do not seed"
    seed = seed if seed != -1 else None
    common.init_random(seed=seed)

    file_log_lvl = LogLevel.DIALOGS if log_to_file else LogLevel.NONE
    logger = DiasysLogger(console_log_lvl=LogLevel.RESULTS, file_log_lvl=file_log_lvl)

    summary_writer = SummaryWriter(log_dir='logs') if use_tensorboard else None

    domain = JSONLookupDomain(name=domain_name)

    bst = HandcraftedBST(domain=domain, logger=logger)
    user = HandcraftedUserSimulator(domain, logger=logger)
    # noise = SimpleNoise(domain=domain, train_error_rate=train_error_rate,
    #                     test_error_rate=test_error_rate, logger=logger)
    policy = DQNPolicy(domain=domain, lr=lr, eps_start=eps_start,
                       gradient_clipping=grad_clipping, buffer_cls=buffer_cls,
                       replay_buffer_size=buffer_size, train_dialogs=train_dialogs,
                       logger=logger, summary_writer=summary_writer)
    evaluator = PolicyEvaluator(domain=domain, use_tensorboard=use_tensorboard,
                                experiment_name=domain_name, logger=logger,
                                summary_writer=summary_writer)
    ds = DialogSystem(services=[user, bst, policy, evaluator], protocol='tcp')
    # ds.draw_system_graph()

    error_free = ds.is_error_free_messaging_pipeline()
    if not error_free:
        ds.print_inconsistencies()

    for _ in range(train_epochs):
        # START TRAIN EPOCH
        evaluator.train()
        policy.train()
        evaluator.start_epoch()
        for episode in range(train_dialogs):
            if episode % 100 == 0:
                print("DIALOG", episode)
            logger.dialog_turn("\n\n!!!!!!!!!!!!!!!! NEW DIALOG !!!!!!!!!!!!!!!!!!!!!!!!!!!!\n\n")
            ds.run_dialog(start_signals={f'user_acts/{domain.get_domain_name()}': []})
        evaluator.end_epoch()
        policy.save()

        # START EVAL EPOCH
        evaluator.eval()
        policy.eval()
        evaluator.start_epoch()
        for episode in range(eval_dialogs):
            logger.dialog_turn("\n\n!!!!!!!!!!!!!!!! NEW DIALOG !!!!!!!!!!!!!!!!!!!!!!!!!!!!\n\n")
            ds.run_dialog(start_signals={f'user_acts/{domain.get_domain_name()}': []})
        evaluator.end_epoch()
    ds.shutdown()
service
¶
DialogSystem
¶
This class will construct a dialog system from the list of services provided to the constructor. It will also handle synchronization for initialization of services before dialog start / after dialog end / on system shutdown and lets you discover potential conflicts in your messaging pipeline. This class is also used to communicate / synchronize with services running on different nodes.
Source code in adviser/services/service.py
class DialogSystem:
"""
This class will constrct a dialog system from the list of services provided to the constructor.
It will also handle synchronization for initalization of services before dialog start / after dialog end / on system shutdown
and lets you discover potential conflicts in you messaging pipeline.
This class is also used to communicate / synchronize with services running on different nodes.
"""
def __init__(self, services: List[Union[Service, RemoteService]], sub_port: int = 65533, pub_port: int = 65534,
reg_port: int = 65535, protocol: str = 'tcp', debug_logger: DiasysLogger = None):
"""
Start the message proxy, create the control channels and register all given services.
Args:
services (List[Union[Service, RemoteService]]): List of all (remote) services to connect to.
Only once they're specified here will they start listening for
messages.
sub_port (int): subscriber port (the proxy's outgoing side is bound to it)
pub_port (int): publisher port (the proxy's incoming side is bound to it)
reg_port (int): registration port for remote services
protocol (str): communication protocol, either 'inproc' or 'tcp' or `ipc`
debug_logger (DiasysLogger): If not `None`, all messages are printed to the logger, including send/receive events.
Can be useful for debugging because you can still see messages received by the `DialogSystem`
even if they are never forwarded (as expected) to your `Service`
Note:
There are no `sub_addr` / `pub_addr` parameters: the proxy and the control
channels always use the address 127.0.0.1.
"""
# node-local topics
self.debug_logger = debug_logger
self.protocol = protocol
self._sub_topics = {}
self._pub_topics = {}
self._remote_identifiers = set()
self._services = [] # collects names and instances of local services
self._start_dialog_services = set() # collects names of local services that subscribe to dialog_start
# node-local sockets
self._domains = set()
# start proxy thread
# publishers connect to pub_port (XSUB side), subscribers to sub_port (XPUB side)
self._proxy_dev = ProcessProxy(in_type=zmq.XSUB, out_type=zmq.XPUB) # , mon_type=zmq.XSUB)
self._proxy_dev.bind_in(f"{protocol}://127.0.0.1:{pub_port}")
self._proxy_dev.bind_out(f"{protocol}://127.0.0.1:{sub_port}")
self._proxy_dev.start()
self._sub_port = sub_port
self._pub_port = pub_port
# thread control
self._start_topics = set()
self._end_topics = set()
self._terminate_topics = set()
self._stopEvent = threading.Event()
# control channels
ctx = Context.instance()
self._control_channel_pub = ctx.socket(zmq.PUB)
self._control_channel_pub.sndhwm = 1100000
self._control_channel_pub.connect(f"{protocol}://127.0.0.1:{pub_port}")
# the SUB socket is created here, but only connected after all services
# registered their ACK subscriptions (see _add_service_info)
self._control_channel_sub = ctx.socket(zmq.SUB)
# register services (local and remote)
remote_services = {}
for service in services:
if isinstance(service, Service):
# register local service; its class name is used if no identifier was given
service_name = type(service).__name__ if service._identifier is None else service._identifier
service._init_pubsub()
self._add_service_info(service_name, service._domain_name, service._sub_topics, service._pub_topics,
service._start_topic, service._end_topic, service._terminate_topic)
service._register_with_dialogsystem()
elif isinstance(service, RemoteService):
# remote services are collected by identifier and registered below
remote_services[getattr(service, 'identifier')] = service
self._register_remote_services(remote_services, reg_port)
self._control_channel_sub.connect(f"{protocol}://127.0.0.1:{sub_port}")
self._setup_dialog_end_listener()
# presumably gives the freshly connected sockets time to settle - TODO confirm
time.sleep(0.25)
def _register_pub_topic(self, publisher, topic: str):
    """ Map a publisher instance to a topic (creating the topic's set on first use) """
    self._pub_topics.setdefault(topic, set()).add(publisher)
def _register_sub_topic(self, subscriber, topic):
    """ Map a subscriber instance to a topic (creating the topic's set on first use) """
    self._sub_topics.setdefault(topic, set()).add(subscriber)
def _register_remote_services(self, remote_services: List[RemoteService], reg_port: int):
"""
Register all remote services.
*Blocking* until a CONF_REGISTER message was received from all of them, confirming they're setup and ready.
Args:
remote_services (List[RemoteService]): all remote services to register.
NOTE(review): despite the annotation, this is used as a dict keyed by the service
identifier (membership test by identifier, `del remote_services[identifier]`),
and the entries are removed from the passed-in object as services confirm.
reg_port (int): registration port for remote services
"""
if len(remote_services) == 0:
return # nothing to register
# Socket to receive registration requests
ctx = Context.instance()
reg_service = ctx.socket(zmq.REP)
reg_service.bind(f'tcp://127.0.0.1:{reg_port}')
# loop until every expected service completed the two-step handshake
# (REGISTER_<id> followed by CONF_REGISTER_<id>)
while len(remote_services) > 0:
# call next remote service
msg, data = reg_service.recv_multipart()
msg = msg.decode("utf-8")
if msg.startswith("REGISTER_"):
# make sure we have a register message
remote_service_identifier = msg[len("REGISTER_"):]
if remote_service_identifier in remote_services:
print(f"registering service {remote_service_identifier}...")
# add remote service interface info
# NOTE(review): pickle.loads on data received over the network is unsafe
# if the registration port is reachable by untrusted peers
domain_name, sub_topics, pub_topics, start_topic, end_topic, terminate_topic = pickle.loads(data)
self._add_service_info(remote_service_identifier, domain_name, sub_topics, pub_topics, start_topic,
end_topic, terminate_topic)
self._remote_identifiers.add(remote_service_identifier)
# acknowledge service registration
reg_service.send(bytes(f'ACK_REGISTER_{remote_service_identifier}', encoding="ascii"))
elif msg.startswith("CONF_REGISTER_"):
# complete registration
remote_service_identifier = msg[len("CONF_REGISTER_"):]
if remote_service_identifier in remote_services:
del remote_services[remote_service_identifier]
print(f"successfully registered service {remote_service_identifier}")
# empty reply to the confirmation message
reg_service.send(bytes(f"", encoding="ascii"))
print("########## Finished registering all remote services ##########")
def _add_service_info(self, service_name: str, domain_name: str, sub_topics: List[str], pub_topics: List[str],
                      start_topic: str, end_topic:str, terminate_topic: str):
    """ Record a service's topics and subscribe to the ACKs of its control topics.

    The topic bookkeeping is what the dialog graph / inconsistency checks
    are built from.

    Args:
        service_name (str): service name
        domain_name (str): domain name
        sub_topics (List[str]): list of all subscribed to topics of the given service
        pub_topics (List[str]): list of all topics the given service publishes to
        start_topic (str): control channel topic for setting given service into `listening` mode
        end_topic (str): control channel topic for setting given service into `non-listening` mode
        terminate_topic (str): control channel topic for stopping given service's listener loops and
                               closing the listener sockets
    """
    self._domains.add(domain_name)
    for subscribed in sub_topics:
        self._register_sub_topic(service_name, subscribed)
    for published in pub_topics:
        self._register_pub_topic(service_name, published)

    # setup control channels
    self._start_topics.add(start_topic)
    self._end_topics.add(end_topic)
    self._terminate_topics.add(terminate_topic)
    # acknowledgements arrive on "ACK/<control topic>"
    for control_topic in (start_topic, end_topic, terminate_topic):
        self._control_channel_sub.setsockopt(zmq.SUBSCRIBE, bytes(f"ACK/{control_topic}", encoding="ascii"))
def _setup_dialog_end_listener(self):
    """ Creates the SUB socket (``self._end_socket``) for listening to Topic.DIALOG_END messages """
    end_socket = Context.instance().socket(zmq.SUB)
    self._end_socket = end_socket
    # subscribe to dialog end from all domains
    end_socket.setsockopt(zmq.SUBSCRIBE, bytes(Topic.DIALOG_END, encoding="ascii"))
    end_socket.connect(f"{self.protocol}://127.0.0.1:{self._sub_port}")
# # add to list of local topics
# if Topic.DIALOG_END not in self._local_sub_topics:
# self._local_sub_topics[Topic.DIALOG_END] = set()
# self._local_sub_topics[Topic.DIALOG_END].add(type(self).__name__)
def stop(self):
    """ Set stop event (can be queried by services via the `terminating()` function) """
    stop_event = self._stopEvent
    stop_event.set()
def terminating(self):
    """ Returns True if the stop event is set (system is stopping), else False """
    stopping = self._stopEvent.is_set()
    return stopping
def shutdown(self):
    """ Shutdown dialog system.

    Sets the stop event, then sends a `terminate` message on every registered
    terminate topic and waits for the matching ACK before moving on to the
    next topic, so the call blocks until all services confirmed.
    Should be called in the end before exiting your program.
    """
    self._stopEvent.set()
    for topic in self._terminate_topics:
        _send_msg(self._control_channel_pub, topic, True)
        _recv_ack(self._control_channel_sub, topic)
def _end_dialog(self):
    """ Block until a dialog-end message with a truthy value arrives, then
    stop all receivers.

    Messages are read from ``self._end_socket``; the second frame is expected
    to be a pickled ``(timestamp, content)`` pair. Errors while receiving or
    decoding a message are printed and the loop keeps waiting. A
    ``KeyboardInterrupt`` ends the wait early (receivers are still stopped).
    ``SystemExit`` and other non-``Exception`` signals are no longer swallowed.
    """
    # listen for Topic.DIALOG_END messages
    while True:
        try:
            msg = self._end_socket.recv_multipart(copy=True)
            # receive message for subscribed topic
            topic = msg[0].decode("ascii")
            # NOTE(review): pickle.loads on message payloads - only safe as
            # long as all publishers are trusted
            timestamp, content = pickle.loads(msg[1])
            if content:
                if self.debug_logger:
                    self.debug_logger.info(f"- (DS): received DIALOG_END message in _end_dialog from topic {topic}")
                self.stop()
                break
        except KeyboardInterrupt:
            break
        except Exception:
            # best effort: report the problem and keep listening
            import traceback
            traceback.print_exc()
            print("ERROR in _end_dialog ")

    # stop receivers (blocking): wait for each ACK before the next topic
    for end_topic in self._end_topics:
        _send_msg(self._control_channel_pub, end_topic, True)
        _recv_ack(self._control_channel_sub, end_topic)
    if self.debug_logger:
        self.debug_logger.info(f"- (DS): all services STOPPED listening")
def _start_dialog(self, start_signals: dict):
    """ Put all services into listening mode, then publish the start signals.

    Clears the stop event, sends a start message on every registered start
    topic (waiting for each ACK), and finally publishes every
    ``topic -> value`` pair from ``start_signals``.
    """
    self._stopEvent.clear()

    # start receivers (blocking)
    for topic in self._start_topics:
        _send_msg(self._control_channel_pub, topic, True)
        _recv_ack(self._control_channel_sub, topic)
    if self.debug_logger:
        self.debug_logger.info(f"- (DS): all services STARTED listening")

    # publish first turn trigger(s)
    for signal_topic, value in start_signals.items():
        _send_msg(self._control_channel_pub, f"{signal_topic}", value)
def run_dialog(self, start_signals: dict = {Topic.DIALOG_END: False}):
    """ Run a complete dialog (blocking).

    The dialog is started by publishing the given ``start_signals`` and ends
    on receiving any `Topic.DIALOG_END` message with a truthy value, so make
    sure at least one service in your dialog graph will publish this message
    eventually.

    Args:
        start_signals (Dict[str, Any]): mapping from topic -> value
            Publishes the value given for each topic to the respective topic.
            Use this to trigger the start of your dialog system.
            The default dict is shared between calls and only read here.
    """
    # phase 1: start listeners and publish the start signals
    self._start_dialog(start_signals)
    # phase 2: block until the dialog has ended and listeners are stopped
    self._end_dialog()
def list_published_topics(self):
    """ Get all declared publisher topics.

    Returns:
        A dictionary with mapping
            topic (str) -> publishing services (Set[str]).
        The result is a deep copy, so changing it does not affect the system.

    Note:
        * Call this method after instantiating all services.
        * Even though a publishing topic might be listed here, there is no guarantee that
          its publisher(s) might ever publish to it.
    """
    snapshot = copy.deepcopy(self._pub_topics)
    return snapshot
def list_subscribed_topics(self):
    """ Get all declared subscribed topics.

    Returns:
        A dictionary with mapping
            topic (str) -> subscribing services (Set[str]).
        The result is a deep copy, so changing it does not affect the system.

    Notes:
        * Call this method after instantiating all services.
    """
    snapshot = copy.deepcopy(self._sub_topics)
    return snapshot
def draw_system_graph(self, name: str = 'system', format: str = "png", show: bool = True):
    """ Draws a graph of the system as a directed graph.

    Services are represented by nodes, messages by directed edges (from publisher to subscriber).
    Subscribed topics without a publisher (errors) are drawn as edges *from* an
    'UNCONNECTED SERVICES' node, published topics without a subscriber (warnings)
    as edges *to* that node. Remote services get a filled blue node.

    Args:
        name (str): used to construct the name of your output file
        format (str): output file format (e.g. png, pdf, jpg, ...)
        show (bool): if True, the graph image will be opened in your default image viewer application

    Requires:
        graphviz library (pip install graphviz)
    """
    from graphviz import Digraph
    graph = Digraph(name=name, format=format)

    # collect all services, errors and warnings
    services = set()
    for publishers in self._pub_topics.values():
        services = services.union(publishers)
    for subscribers in self._sub_topics.values():
        services = services.union(subscribers)
    errors, warnings = self.list_inconsistencies()

    # add services as nodes
    for service in services:
        if service in self._remote_identifiers:
            # remote service
            graph.node(service, color='#1f618d', style='filled', fontcolor='white', shape='box')
        else:
            # local service
            graph.node(service, color='#1c2833', shape='box')
    if len(errors) > 0 or len(warnings) > 0:
        graph.node('UNCONNECTED SERVICES', style='filled', color='#922b21', fontcolor='white', shape='box')

    # draw connections from publisher to subscribers as edges
    for topic, publishers in self._pub_topics.items():
        for receiver in self._sub_topics.get(topic, []):
            for publisher in publishers:
                graph.edge(publisher, receiver, label=topic)

    # draw warnings and errors as edges to node 'UNCONNECTED SERVICES'
    for topic, receivers in errors.items():
        for receiver in receivers:
            graph.edge('UNCONNECTED SERVICES', receiver, color='#c34400', fontcolor='#c34400', label=topic)
    for topic, publishers in warnings.items():
        for publisher in publishers:
            graph.edge(publisher, 'UNCONNECTED SERVICES', color='#e37c02', fontcolor='#e37c02', label=topic)

    # draw graph
    graph.render(view=show, cleanup=True)
def list_inconsistencies(self):
    """ Checks for potential errors in the current messaging pipeline:
    e.g. len(list_inconsistencies()[0]) == 0 -> error free pipeline

    A published topic is considered to reach a subscription if it *starts
    with* the subscribed topic (prefix match).
    (Potential) Errors are subscribed topics that no published topic matches.
    Warnings are published topics that match no subscribed topic.

    Returns:
        A tuple of dictionaries:
        * the first dictionary contains potential errors (with the mapping topics -> subscribing services)
        * the second dictionary contains warnings (with the mapping topics -> publishing services).

    Notes:
        * Call this method after instantiating all services.
        * Even if there are no errors returned by this method, there is no guarantee that all publishers
          eventually publish to their respective topics.
    """
    published = self._pub_topics
    subscribed = self._sub_topics

    # subscribers w/o any publisher whose topic has the subscription as prefix
    errors = {
        sub_topic: subscribers
        for sub_topic, subscribers in subscribed.items()
        if not any(pub_topic.startswith(sub_topic) for pub_topic in published)
    }
    # publishers w/o any subscription that is a prefix of their topic
    warnings = {
        pub_topic: publishers
        for pub_topic, publishers in published.items()
        if not any(pub_topic.startswith(sub_topic) for sub_topic in subscribed)
    }
    return errors, warnings
def print_inconsistencies(self):
    """ Prints the result of `list_inconsistencies()` to the console.

    (Potential) Errors (subscribed topics without publishers) are printed
    first, wrapped in the ANSI escape code for red; warnings (published
    topics without subscribers) follow, wrapped in the code for yellow.

    Notes:
        * Call this method after instantiating all services.
        * Even if there are no errors printed by this method, there is no guarantee that all publishers
          eventually publish to their respective topics.
    """
    # console colors
    yellow, red, reset = '\033[93m', '\033[91m', '\033[0m'

    errors, warnings = self.list_inconsistencies()

    print(red)
    print("(Potential) Errors (subscribed topics without publishers):")
    for topic, subscribers in errors.items():
        print(f" topic: '{topic}', subscribed to in services: {subscribers}")
    print(reset)

    print(yellow)
    print("Warnings (published topics without subscribers):")
    for topic, publishers in warnings.items():
        print(f" topic: '{topic}', published in services: {publishers}")
    print(reset)
def is_error_free_messaging_pipeline(self) -> bool:
    """ Tell whether the current messaging pipeline has no potential errors.

    (Potential) errors are subscribed topics without publishers.

    Returns:
        True, if no potential errors could be found - else, False

    Notes:
        * Call this method after instantiating all services.
        * Lists only node-local (or process-local) inconsistencies.
        * A `True` result is no guarantee that all publishers eventually
          publish to their respective topics.
    """
    # the first element of the result holds the errors, the second the warnings
    potential_errors = self.list_inconsistencies()[0]
    return not potential_errors
__init__(self, services, sub_port=65533, pub_port=65534, reg_port=65535, protocol='tcp', debug_logger=None)
special
¶
Parameters:
Name | Type | Description | Default |
---|---|---|---|
services |
List[Union[Service, RemoteService]] |
List of all (remote) services to connect to. Only once they're specified here will they start listening for messages. |
required |
sub_port(int) |
subscriber port |
required | |
sub_addr(str) |
IP-address or domain name of proxy subscriber interface (e.g. 127.0.0.1 for your local machine) |
required | |
pub_port(int) |
publisher port |
required | |
pub_addr(str) |
IP-address or domain name of proxy publisher interface (e.g. 127.0.0.1 for your local machine) |
required | |
reg_port |
int |
registration port for remote services |
65535 |
protocol(str) |
communication protocol, either 'inproc', 'tcp' or 'ipc' |
required | |
debug_logger |
DiasysLogger |
If not None, all messages are printed to the logger, including send/receive events. |
None |
Source code in adviser/services/service.py
def __init__(self, services: List[Union[Service, RemoteService]], sub_port: int = 65533, pub_port: int = 65534,
reg_port: int = 65535, protocol: str = 'tcp', debug_logger: DiasysLogger = None):
"""
Start the message proxy, create the control channels and register all given services.
Args:
services (List[Union[Service, RemoteService]]): List of all (remote) services to connect to.
Only once they're specified here will they start listening for
messages.
sub_port (int): subscriber port (the outgoing side of the proxy is bound to it on 127.0.0.1)
pub_port (int): publisher port (the incoming side of the proxy is bound to it on 127.0.0.1)
reg_port (int): registration port for remote services
protocol (str): communication protocol, either 'inproc' or 'tcp' or `ipc`
debug_logger (DiasysLogger): If not `None`, all messages are printed to the logger, including send/receive events.
Can be useful for debugging because you can still see messages received by the `DialogSystem`
even if they are never forwarded (as expected) to your `Service`
"""
# node-local topics
self.debug_logger = debug_logger
self.protocol = protocol
self._sub_topics = {}
self._pub_topics = {}
self._remote_identifiers = set()
self._services = [] # collects names and instances of local services
self._start_dialog_services = set() # collects names of local services that subscribe to dialog_start
# node-local sockets
self._domains = set()
# start proxy thread
# publishers connect to the XSUB side (pub_port), subscribers to the XPUB side (sub_port)
self._proxy_dev = ProcessProxy(in_type=zmq.XSUB, out_type=zmq.XPUB) # , mon_type=zmq.XSUB)
self._proxy_dev.bind_in(f"{protocol}://127.0.0.1:{pub_port}")
self._proxy_dev.bind_out(f"{protocol}://127.0.0.1:{sub_port}")
self._proxy_dev.start()
self._sub_port = sub_port
self._pub_port = pub_port
# thread control
self._start_topics = set()
self._end_topics = set()
self._terminate_topics = set()
self._stopEvent = threading.Event()
# control channels
ctx = Context.instance()
self._control_channel_pub = ctx.socket(zmq.PUB)
self._control_channel_pub.sndhwm = 1100000
self._control_channel_pub.connect(f"{protocol}://127.0.0.1:{pub_port}")
# the subscriber side is only connected after all services are registered (see below)
self._control_channel_sub = ctx.socket(zmq.SUB)
# register services (local and remote)
remote_services = {}
for service in services:
if isinstance(service, Service):
# register local service
# the class name serves as service name unless an explicit identifier was set
service_name = type(service).__name__ if service._identifier is None else service._identifier
service._init_pubsub()
self._add_service_info(service_name, service._domain_name, service._sub_topics, service._pub_topics,
service._start_topic, service._end_topic, service._terminate_topic)
service._register_with_dialogsystem()
elif isinstance(service, RemoteService):
# remote services are collected by identifier and registered together afterwards
remote_services[getattr(service, 'identifier')] = service
self._register_remote_services(remote_services, reg_port)
self._control_channel_sub.connect(f"{protocol}://127.0.0.1:{sub_port}")
self._setup_dialog_end_listener()
# NOTE(review): presumably gives the sockets time to finish connecting - confirm
time.sleep(0.25)
draw_system_graph(self, name='system', format='png', show=True)
¶
Draws a graph of the system as a directed graph. Services are represented by nodes, messages by directed edges (from publisher to subscriber). Warnings are drawn as yellow edges (and the missing subscribers represented by an 'UNCONNECTED SERVICES' node), errors as red edges (and the missing publishers represented by the 'UNCONNECTED SERVICES' node as well). Will mark remote services with blue.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str |
used to construct the name of your output file |
'system' |
format |
str |
output file format (e.g. png, pdf, jpg, ...) |
'png' |
show |
bool |
if True, the graph image will be opened in your default image viewer application |
True |
Requires
graphviz library (pip install graphviz)
Source code in adviser/services/service.py
def draw_system_graph(self, name: str = 'system', format: str = "png", show: bool = True):
    """ Render the system as a directed graph.

    Every service becomes a node, every message a directed edge from publisher to subscriber.
    Topics reported by `list_inconsistencies()` are attached to an extra
    'UNCONNECTED SERVICES' node: errors (missing publishers) as edges leaving that node,
    warnings (missing subscribers) as edges entering it. Remote services are drawn as
    filled blue boxes.

    Args:
        name (str): used to construct the name of your output file
        format (str): output file format (e.g. png, pdf, jpg, ...)
        show (bool): if True, the graph image will be opened in your default image viewer application

    Requires:
        graphviz library (pip install graphviz)
    """
    from graphviz import Digraph

    unconnected = 'UNCONNECTED SERVICES'
    graph = Digraph(name=name, format=format)

    # gather every service that publishes or subscribes to at least one topic
    services = set()
    for topic_map in (self._pub_topics, self._sub_topics):
        for members in topic_map.values():
            services = services.union(members)
    errors, warnings = self.list_inconsistencies()

    # one node per service; remote services get a filled blue box
    for service in services:
        if service not in self._remote_identifiers:
            graph.node(service, color='#1c2833', shape='box')  # local service
        else:
            graph.node(service, color='#1f618d', style='filled', fontcolor='white', shape='box')  # remote service
    if errors or warnings:
        graph.node(unconnected, style='filled', color='#922b21', fontcolor='white', shape='box')

    # regular connections: only topics with an identically named subscription get an edge here
    for topic, publishers in self._pub_topics.items():
        for receiver in self._sub_topics.get(topic, []):
            for publisher in publishers:
                graph.edge(publisher, receiver, label=topic)

    # errors: subscribed topics nobody publishes to
    for topic, receivers in errors.items():
        for receiver in receivers:
            graph.edge(unconnected, receiver, color='#c34400', fontcolor='#c34400', label=topic)
    # warnings: published topics nobody subscribes to
    for topic, publishers in warnings.items():
        for publisher in publishers:
            graph.edge(publisher, unconnected, color='#e37c02', fontcolor='#e37c02', label=topic)

    # write the output file (the intermediate source file is removed again)
    graph.render(view=show, cleanup=True)
is_error_free_messaging_pipeline(self)
¶
Checks the current messaging pipeline for potential errors.
(Potential) Errors are defined in this context as subscribed topics without publishers.
Returns:
Type | Description |
---|---|
bool |
True, if no potential errors could be found - else, False |
Notes
- Call this method after instantiating all services.
- Lists only node-local (or process-local) inconsistencies.
- Even if there are no errors returned by this method, there is not guarantee that all publishers eventually publish to their respective topics.
Source code in adviser/services/service.py
def is_error_free_messaging_pipeline(self) -> bool:
    """ Tell whether the current messaging pipeline has no potential errors.

    (Potential) errors are subscribed topics without publishers.

    Returns:
        True, if no potential errors could be found - else, False

    Notes:
        * Call this method after instantiating all services.
        * Lists only node-local (or process-local) inconsistencies.
        * A `True` result is no guarantee that all publishers eventually
          publish to their respective topics.
    """
    # the first element of the result holds the errors, the second the warnings
    potential_errors = self.list_inconsistencies()[0]
    return not potential_errors
list_inconsistencies(self)
¶
Checks for potential errors in the current messaging pipeline: e.g. len(list_inconsistencies()[0]) == 0 -> error free pipeline
(Potential) Errors are defined in this context as subscribed topics without publishers. Warnings are defined in this context as published topics without subscribers.
Returns:
Type | Description |
---|---|
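A tuple of dictionaries |
the first dictionary contains potential errors (mapping topics -> subscribing services), the second dictionary contains warnings (mapping topics -> publishing services) |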
Notes
- Call this method after instantiating all services.
- Even if there are no errors returned by this method, there is not guarantee that all publishers eventually publish to their respective topics.
Source code in adviser/services/service.py
def list_inconsistencies(self):
    """ Collect potential errors and warnings of the current messaging pipeline.

    (Potential) errors are subscribed topics without publishers, warnings are published
    topics without subscribers. A subscription counts as matching every published topic
    that starts with the subscribed topic string, e.g.
    len(list_inconsistencies()[0]) == 0 -> error free pipeline

    Returns:
        A tuple of dictionaries:
        * the first dictionary contains potential errors (with the mapping topics -> subscribing services)
        * the second dictionary contains warnings (with the mapping topics -> publishing services).

    Notes:
        * Call this method after instantiating all services.
        * Even if no errors are returned, there is no guarantee that all publishers
          eventually publish to their respective topics.
    """
    published = self._pub_topics
    subscribed = self._sub_topics

    # subscribers w/o publishers: no published topic starts with the subscribed prefix
    errors = {
        sub_topic: subscribers
        for sub_topic, subscribers in subscribed.items()
        if not any(pub_topic.startswith(sub_topic) for pub_topic in published)
    }
    # publishers w/o subscribers: no subscribed prefix matches the published topic
    warnings = {
        pub_topic: publishers
        for pub_topic, publishers in published.items()
        if not any(pub_topic.startswith(sub_topic) for sub_topic in subscribed)
    }
    return errors, warnings
list_published_topics(self)
¶
Get all declared publisher topics.
Returns:
Type | Description |
---|---|
A dictionary with mapping topic (str) -> publishing services (Set[str]). |
Note
- Call this method after instantiating all services.
- Even though a publishing topic might be listed here, there is no guarantee that its publisher(s) might ever publish to it.
Source code in adviser/services/service.py
def list_published_topics(self):
    """ Return all declared publisher topics.

    Returns:
        A dictionary with mapping
            topic (str) -> publishing services (Set[str]).

    Note:
        * Call this method after instantiating all services.
        * A topic being listed here does not guarantee that its publisher(s)
          will ever publish to it.
    """
    # hand out a deep copy so that callers cannot modify the internal registry
    topics_snapshot = copy.deepcopy(self._pub_topics)
    return topics_snapshot
list_subscribed_topics(self)
¶
Get all declared subscribed topics.
Returns:
Type | Description |
---|---|
A dictionary with mapping topic (str) -> subscribing services (Set[str]). |
Notes
- Call this method after instantiating all services.
Source code in adviser/services/service.py
print_inconsistencies(self)
¶
Checks for potential errors in the current messaging pipeline (e.g. len(list_inconsistencies()[0]) == 0 -> error free pipeline) and prints them to the console.
(Potential) Errors are defined in this context as subscribed topics without publishers. Warnings are defined in this context as published topics without subscribers.
Notes
- Call this method after instantiating all services.
- Even if there are no errors returned by this method, there is not guarantee that all publishers eventually publish to their respective topics.
Source code in adviser/services/service.py
def print_inconsistencies(self):
    """ Print the result of `list_inconsistencies()` to the console.

    (Potential) errors are subscribed topics without publishers and are printed in red.
    Warnings are published topics without subscribers and are printed in yellow.

    Notes:
        * Call this method after instantiating all services.
        * An empty error section is no guarantee that all publishers eventually
          publish to their respective topics.
    """
    # ANSI escape sequences for coloured console output
    red, yellow, reset = '\033[91m', '\033[93m', '\033[0m'
    missing_publishers, missing_subscribers = self.list_inconsistencies()

    print(red)
    print("(Potential) Errors (subscribed topics without publishers):")
    for topic, subscribers in missing_publishers.items():
        print(f" topic: '{topic}', subscribed to in services: {subscribers}")
    print(reset)

    print(yellow)
    print("Warnings (published topics without subscribers):")
    for topic, publishers in missing_subscribers.items():
        print(f" topic: '{topic}', published in services: {publishers}")
    print(reset)
run_dialog(self, start_signals={'dialog_end': False})
¶
Run a complete dialog (blocking).
Dialog will be started via messages to the topics specified in start_signals
.
The dialog will end on receiving any Topic.DIALOG_END
message with value 'True',
so make sure at least one service in your dialog graph will publish this message eventually.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
start_signals |
Dict[str, Any] |
mapping from topic -> value Publishes the value given for each topic to the respective topic. Use this to trigger the start of your dialog system. |
{'dialog_end': False} |
Source code in adviser/services/service.py
def run_dialog(self, start_signals: dict = None):
    """ Run a complete dialog (blocking).

    Dialog will be started via messages to the topics specified in `start_signals`.
    The dialog will end on receiving any `Topic.DIALOG_END` message with value 'True',
    so make sure at least one service in your dialog graph will publish this message eventually.

    Args:
        start_signals (Dict[str, Any]): mapping from topic -> value
            Publishes the value given for each topic to the respective topic.
            Use this to trigger the start of your dialog system.
            Defaults to `{Topic.DIALOG_END: False}` when omitted (or `None`).
    """
    if start_signals is None:
        # build the default per call: a dict literal in the signature would be one
        # shared object, so a mutation anywhere downstream would leak into later calls
        start_signals = {Topic.DIALOG_END: False}
    self._start_dialog(start_signals)
    self._end_dialog()
shutdown(self)
¶
Shutdown dialog system.
This will trigger terminate
messages to be sent to all registered services to stop their listener loops.
Should be called in the end before exiting your program.
Blocks until all services sent ACK's confirming they're stopped.
Source code in adviser/services/service.py
def shutdown(self):
    """ Shut the dialog system down.

    Sets the stop event and then sends a `terminate` message to every registered
    terminate topic, waiting for the matching ACK after each message.
    Should be called in the end before exiting your program.
    Blocks until all services sent ACK's confirming they're stopped.
    """
    self._stopEvent.set()
    for topic in self._terminate_topics:
        # one service at a time: send the request, then block until it is acknowledged
        _send_msg(self._control_channel_pub, topic, True)
        _recv_ack(self._control_channel_sub, topic)
stop(self)
¶
terminating(self)
¶
RemoteService
¶
This is a placeholder to be used in the service list argument when constructing a
DialogSystem:
* Run the real
Service instance on a remote node, give it a *UNIQUE* identifier
* call
run_standalone() on this instance
* Instantiate a remote service on the node about to run the
DialogSystem, assign the *SAME* identifier to it
* add it to the
DialogSystem service list
* Now, when calling the constructor of
DialogSystem, you should see messages informing you about the
successful connection, or if the system is still trying to connect, it will block until connected to
the remote service.
Source code in adviser/services/service.py
class RemoteService:
    """
    Placeholder for a `Service` running on another node, to be used in the service list
    argument when constructing a `DialogSystem`:

    * Run the real `Service` instance on a remote node, give it a *UNIQUE* identifier
    * call `run_standalone()` on this instance
    * Instantiate a remote service on the node about to run the `DialogSystem`, assign the *SAME* identifier to it
    * add it to the `DialogSystem` service list
    * Now, when calling the constructor of `DialogSystem`, you should see messages informing you about the
      successful connection, or if the system is still trying to connect, it will block until connected to
      the remote service.
    """

    def __init__(self, identifier: str):
        """
        Args:
            identifier (str): the *UNIQUE* identifier of the remote service instance this placeholder stands for
        """
        self.identifier = identifier
__init__(self, identifier)
special
¶
Parameters:
Name | Type | Description | Default |
---|---|---|---|
identifier |
str |
the UNIQUE identifier to call the remote service instance |
required |
Service
¶
Service base class. Inherit from this class, if you want to publish / subscribe to topics (Don't forget to call the super constructor!). You may decorate arbitrary functions in the child class with the services.service.PublishSubscribe decorator for this purpose.
A Service
will only start listening to messages once it is added to a DialogSystem
(or calling run_standalone()
in the remote case and adding a corresponding RemoteService
to the DialogSystem
).
Source code in adviser/services/service.py
class Service:
"""
Service base class.
Inherit from this class, if you want to publish / subscribe to topics *(Don't forget to call the super constructor!)*.
You may decorate arbitrary functions in the child class with the services.service.PublishSubscribe decorator
for this purpose.
Note: A `Service` will only start listening to messages once it is added to a `DialogSystem`
(or calling `run_standalone()` in the remote case and adding a corresponding `RemoteService` to the `DialogSystem`).
"""
def __init__(self, domain: Union[str, Domain] = "", sub_topic_domains: Dict[str, str] = None,
             pub_topic_domains: Dict[str, str] = None, ds_host_addr: str = "127.0.0.1",
             sub_port: int = 65533, pub_port: int = 65534, protocol: str = "tcp",
             debug_logger: DiasysLogger = None, identifier: str = None):
    """
    Create a new service instance *(call this super constructor from your inheriting classes!)*.

    Args:
        domain (Union[str, Domain]): The domain(-name) of your service (or empty string, if domain-agnostic).
            If a domain(-name) is set, it will automatically filter out all messages from other domains.
            If no domain(-name) is set, messages from all domains will be received.
        sub_topic_domains (Dict[str, str]): change subscribed to topics to listen to a specific domain
            (e.g. 'erase'/append a domain for a specific topic).
            Defaults to a new empty dict per instance.
        pub_topic_domains (Dict[str, str]): change published topics to a specific domain
            (e.g. 'erase'/append a domain for a specific topic).
            Defaults to a new empty dict per instance.
        ds_host_addr (str): IP-address of the parent `DialogSystem` (default: localhost)
        sub_port (int): subscriber port following zmq's XSUB/XPUB pattern
        pub_port (int): publisher port following zmq's XSUB/XPUB pattern
        protocol (string): communication protocol with `DialogSystem` - has to match!
            Possible options: `tcp`, `inproc`, `ipc`
        debug_logger (DiasysLogger): If not `None`, all messages are printed to the logger, including send/receive events.
            Can be useful for debugging because you can still see messages received by the `DialogSystem`
            even if they are never forwarded (as expected) to your `Service`.
        identifier (str): Set this to a *UNIQUE* identifier per service to be run remotely.
            See `RemoteService` for more details.
    """
    self.is_training = False
    self.domain = domain
    # get domain name (gets appended to all sub/pub topics so that different domain topics don't get shared)
    if domain is not None:
        self._domain_name = domain.get_domain_name() if isinstance(domain, Domain) else domain
    else:
        self._domain_name = ""
    # a `{}` default in the signature would be one dict shared by every instance that
    # relies on the default - create a fresh one per instance instead
    self._sub_topic_domains = {} if sub_topic_domains is None else sub_topic_domains
    self._pub_topic_domains = {} if pub_topic_domains is None else pub_topic_domains

    # socket information
    self._host_addr = ds_host_addr
    self._sub_port = sub_port
    self._pub_port = pub_port
    self._protocol = protocol
    self._identifier = identifier

    self.debug_logger = debug_logger

    # topic bookkeeping, filled while the decorated functions are set up
    self._sub_topics = set()
    self._pub_topics = set()
    self._publish_sockets = dict()

    # control topics of the individual listener functions of this service
    self._internal_start_topics = dict()
    self._internal_end_topics = dict()
    self._internal_terminate_topics = dict()

    # NOTE: class name + memory pointer make topic unique (required, e.g. for running multiple instances of same module!)
    topic_prefix = f"{type(self).__name__}/{id(self)}"
    self._start_topic = f"{topic_prefix}/START"
    self._end_topic = f"{topic_prefix}/END"
    self._terminate_topic = f"{topic_prefix}/TERMINATE"
    self._train_topic = f"{topic_prefix}/TRAIN"
    self._eval_topic = f"{topic_prefix}/EVAL"
def _init_pubsub(self):
""" Search for all functions decorated with the `PublishSubscribe` decorator and call the setup methods for them """
for func_name in dir(self):
func_inst = getattr(self, func_name)
if hasattr(func_inst, "pubsub"):
# found decorated publisher / subscriber function -> setup sockets and listeners
self._setup_listener(func_inst, getattr(func_inst, "sub_topics"),
getattr(func_inst, 'queued_sub_topics'))
self._setup_publishers(func_inst, getattr(func_inst, "pub_topics"))
def _register_with_dialogsystem(self):
""" Start listening to dialog system control channel messages """
self._setup_dialog_ctrl_msg_listener()
Thread(target=self._control_channel_listener).start()
def _setup_listener(self, func_instance, topics: List[str], queued_topics: List[str]):
"""
Starts a new subscription thread for a function decorated with `services.service.PublishSubscribe`.
Args:
func_instance (function): instance of the function that was decorated with `services.service.PublishSubscribe`.
topics (List[str]): list of subscribed topics (drops all but most recent messages before function call)
queued_topics (List[str]): list for subscribed topics (drops no messages, forward a list of received messages to function call)
"""
if len(topics + queued_topics) == 0:
# no subscribed to topics - no need to setup anything (e.g. only publisher)
return
# ensure that sub_topics and queued_sub_topics don't intersect (otherwise, both would set same function argument value)
assert set(topics).isdisjoint(queued_topics), "sub_topics and queued_sub_topics have to be disjoint!"
# setup socket
ctx = Context.instance()
subscriber = ctx.socket(zmq.SUB)
# subscribe to all listed topics
for topic in topics + queued_topics:
topic_domain_str = f"{topic}/{self._domain_name}" if self._domain_name else topic
if topic in self._sub_topic_domains:
# overwrite domain for this specific topic and service instance
topic_domain_str = f"{topic}/{self._sub_topic_domains[topic]}" if self._sub_topic_domains[topic] else topic
subscriber.setsockopt(zmq.SUBSCRIBE, bytes(topic_domain_str, encoding="ascii"))
# subscribe to control channels
subscriber.setsockopt(zmq.SUBSCRIBE, bytes(f"{func_instance}/START", encoding="ascii"))
subscriber.setsockopt(zmq.SUBSCRIBE, bytes(f"{func_instance}/END", encoding="ascii"))
subscriber.setsockopt(zmq.SUBSCRIBE, bytes(f"{func_instance}/TERMINATE", encoding="ascii"))
subscriber.connect(f"{self._protocol}://{self._host_addr}:{self._sub_port}")
self._internal_start_topics[f"{str(func_instance)}/START"] = str(func_instance)
self._internal_end_topics[f"{str(func_instance)}/END"] = str(func_instance)
self._internal_terminate_topics[f"{str(func_instance)}/TERMINATE"] = str(func_instance)
# register and run listener thread
listener_thread = Thread(target=self._receiver_thread, args=(subscriber, func_instance,
topics, queued_topics,
f"{str(func_instance)}/START",
f"{str(func_instance)}/END",
f"{str(func_instance)}/TERMINATE"))
listener_thread.start()
# add to list of local topics
# TODO maybe add topic_domain_str instead for more clarity?
self._sub_topics.update(topics + queued_topics)
def _setup_publishers(self, func_instance, topics):
""" Creates a publish socket for a function decorated with `services.service.PublishSubscribe`. """
if len(topics) == 0:
return # no topics - no need for a socket
# setup publish socket
ctx = Context.instance()
publisher = ctx.socket(zmq.PUB)
publisher.sndhwm = 1100000
publisher.connect(f"{self._protocol}://{self._host_addr}:{self._pub_port}")
self._publish_sockets[func_instance] = publisher
# add to list of local topics
self._pub_topics.update(topics)
def _setup_dialog_ctrl_msg_listener(self):
    """ Create the sockets used to exchange control messages with the `DialogSystem`.

    Sets up a subscriber for the control topics of this service, a publisher for
    control messages / acknowledgements, and a subscriber for the ACKs of the
    service's own listener functions.
    """
    ctx = Context.instance()
    sub_endpoint = f"{self._protocol}://{self._host_addr}:{self._sub_port}"
    pub_endpoint = f"{self._protocol}://{self._host_addr}:{self._pub_port}"

    # setup receiver for dialog system control messages
    self._control_channel_sub = ctx.socket(zmq.SUB)
    for control_topic in (self._start_topic, self._end_topic, self._terminate_topic,
                          self._train_topic, self._eval_topic):
        self._control_channel_sub.setsockopt(zmq.SUBSCRIBE, bytes(control_topic, encoding="ascii"))
    self._control_channel_sub.connect(sub_endpoint)

    # setup sender for dialog system control message acknowledgements
    self._control_channel_pub = ctx.socket(zmq.PUB)
    self._control_channel_pub.sndhwm = 1100000
    self._control_channel_pub.connect(pub_endpoint)

    # setup receiver for internal ACK messages (end, start and terminate topics, in that order)
    self._internal_control_channel_sub = ctx.socket(zmq.SUB)
    internal_topics = []
    for topic_map in (self._internal_end_topics, self._internal_start_topics,
                      self._internal_terminate_topics):
        internal_topics.extend(topic_map.keys())
    for internal_ctrl_topic in internal_topics:
        self._internal_control_channel_sub.setsockopt(
            zmq.SUBSCRIBE, bytes(f"ACK/{internal_ctrl_topic}", encoding="ascii"))
    self._internal_control_channel_sub.connect(sub_endpoint)
def _control_channel_listener(self):
    """ Listen to control messages from the `DialogSystem` in a loop and dispatch them.

    Meant to be called in a thread. Handles the start, end, terminate, train and eval
    topics of this service and acknowledges each of them on the control publish channel.
    The loop ends once a terminate message was handled, or on `KeyboardInterrupt`.
    Any other exception is printed with its traceback and the loop keeps running.
    """
    listen = True
    while listen:
        try:
            # receive message for subscribed control topic
            msg = self._control_channel_sub.recv_multipart(copy=True)
            topic = msg[0].decode("ascii")
            # NOTE(security): unpickling runs arbitrary code for crafted payloads -
            # only connect this socket to trusted peers
            timestamp, content = pickle.loads(msg[1])
            if topic == self._start_topic:
                # initialize dialog state
                self.dialog_start()
                # set all listeners of this service to listening mode (block until they are listening)
                for internal_start_topic in self._internal_start_topics:
                    _send_msg(self._control_channel_pub, internal_start_topic, True)
                    _recv_ack(self._internal_control_channel_sub, internal_start_topic)
                _send_ack(self._control_channel_pub, self._start_topic)
            elif topic == self._end_topic:
                # stop all listeners of this service (block until they stopped)
                for internal_end_topic in self._internal_end_topics:
                    _send_msg(self._control_channel_pub, internal_end_topic, True)
                    _recv_ack(self._internal_control_channel_sub, internal_end_topic, True)
                self.dialog_end()
                _send_ack(self._control_channel_pub, self._end_topic)
            elif topic == self._terminate_topic:
                # terminate all listeners of this service (block until they stopped)
                for internal_terminate_topic in self._internal_terminate_topics:
                    _send_msg(self._control_channel_pub, internal_terminate_topic, True)
                    _recv_ack(self._internal_control_channel_sub, internal_terminate_topic, True)
                self.dialog_exit()
                _send_ack(self._control_channel_pub, self._terminate_topic)
                listen = False
            elif topic == self._train_topic:
                self.train()
                _send_ack(self._control_channel_pub, self._train_topic)
            elif topic == self._eval_topic:
                self.eval()
                _send_ack(self._control_channel_pub, self._eval_topic)
            else:
                if self.debug_logger:
                    # build one message string: a logging-style `info(msg, *args)` would treat
                    # extra positional arguments as %-format arguments, and the message has
                    # no placeholders for them (assumes DiasysLogger follows that API)
                    self.debug_logger.info(
                        f"- (Service): received unknown control message from topic {topic}"
                        f" with content {content}")
        except KeyboardInterrupt:
            break
        except Exception:
            # keep listening after a handler error, but let SystemExit & co. propagate
            import traceback
            print("ERROR in Service: _control_channel_listener")
            traceback.print_exc()
def dialog_start(self):
    """ Hook called before the first message to a new dialog is published.

    Does nothing by default. Overwrite this function to set/reset dialog-level variables.
    """
def dialog_end(self):
    """ Hook called after a dialog ended (Topics.DIALOG_END message was received).

    Does nothing by default. Overwrite this function to record dialog-level information.
    """
def dialog_exit(self):
    """ Hook called when the dialog system is shutting down.

    Does nothing by default. Overwrite this function to stop your threads and cleanup any open resources.
    """
def train(self):
    """ Switch this module to training mode by setting `is_training` to True. """
    training_mode = True
    self.is_training = training_mode
def eval(self):
    """ Switch this module to eval mode by setting `is_training` to False. """
    training_mode = False
    self.is_training = training_mode
def run_standalone(self, host_reg_port: int = 65535):
    """
    Run this service as a standalone service (without a `DialogSystem`) on a remote node.
    Use a `RemoteService` with *corresponding identifier* on the `DialogSystem` node to connect both.

    Sends the service's topic information to the registration port of the `DialogSystem` node,
    then waits for the acknowledgement addressed to this service's identifier and confirms it.

    Note: this call is blocking!

    Args:
        host_reg_port (int): The port on the `DialogSystem` node listening for `Service` register requests
    """
    assert self._identifier is not None, "running a service on a remote node requires a unique identifier"
    print("Waiting for dialog system host...")

    # send service info to dialog system node
    self._init_pubsub()
    registration = Context.instance().socket(zmq.REQ)
    registration.connect(f"tcp://{self._host_addr}:{host_reg_port}")
    service_info = (self._domain_name, self._sub_topics, self._pub_topics, self._start_topic, self._end_topic,
                    self._terminate_topic)
    payload = pickle.dumps(service_info)
    registration.send_multipart((bytes(f"REGISTER_{self._identifier}", encoding="ascii"), payload))

    # wait for registration confirmation
    ack_prefix = "ACK_REGISTER_"
    while True:
        reply = registration.recv().decode("utf-8")
        if not reply.startswith(ack_prefix):
            continue
        if reply[len(ack_prefix):] != self._identifier:
            continue  # acknowledgement for a different identifier
        self._register_with_dialogsystem()
        confirmation = bytes(f"CONF_REGISTER_{self._identifier}", encoding="ascii")
        registration.send_multipart((confirmation, pickle.dumps(True)))
        print("Done")
        break
def get_all_subscribed_topics(self):
    """
    Returns:
        A deep copy of the set of all topics subscribed to by this `Service`
    """
    subscribed_snapshot = copy.deepcopy(self._sub_topics)
    return subscribed_snapshot
def get_all_published_topics(self):
    """
    Returns:
        A deep copy of the set of all topics published to by this `Service`
    """
    published_snapshot = copy.deepcopy(self._pub_topics)
    return published_snapshot
def _receiver_thread(self, subscriber: Socket, func_instance,
                     topics: Iterable[str], queued_topics: Iterable[str],
                     start_topic: str, end_topic: str, terminate_topic: str):
    """
    Loop for receiving messages.
    Will continue until a message for `terminate_topic` is received.

    Handles waiting for messages, decoding, unpickling and subscription topic to
    service function keyword mapping.

    Meant to be run in a Thread!

    Args:
        subscriber (Socket): subscriber socket
        func_instance (function instance): the decorated subscriber function instance to be called with the received messages
        topics (Iterable[str]): all last-message-only topics the decorated `func_instance` subscribes to
        queued_topics (Iterable[str]): all collect-all-messages-since-last-call topics the decorated `func_instance` subscribes to
        start_topic (str): Control message topic to set this specific `function_instance` into listening mode (receive all non-control messages)
        end_topic (str): Control message topic to set this specific `function_instance` into non-listening mode (ignore all non-control messages)
        terminate_topic (str): Control message topic to end the listener loop for this specific `function_instance`.
                               Also closes the socket before returning.
    """
    ctx = Context.instance()
    control_channel_pub = ctx.socket(zmq.PUB)
    control_channel_pub.sndhwm = 1100000
    control_channel_pub.connect(f"{self._protocol}://{self._host_addr}:{self._pub_port}")

    values = {}
    timestamps = {}
    # materialize both topic collections so any iterable (not only lists) can be passed
    latest_only_topics = list(topics)
    all_sub_topics = latest_only_topics + list(queued_topics)
    num_topics = len(all_sub_topics)

    active = False
    terminating = False
    while not terminating:
        try:
            msg = subscriber.recv_multipart(copy=True)
            topic = msg[0].decode("ascii")
            # based on topic, decide what to do
            if topic == start_topic:
                # reset values and start listening to non-control messages
                values = {}
                timestamps = {}
                active = True
                _send_ack(control_channel_pub, start_topic)
            elif topic == end_topic:
                # ignore all non-control messages
                active = False
                _send_ack(control_channel_pub, end_topic)
            elif topic == terminate_topic:
                # shutdown listener thread by exiting loop
                active = False
                _send_ack(control_channel_pub, terminate_topic)
                terminating = True
            elif active:
                # non-control message -> process it
                # NOTE(security): pickle.loads can execute arbitrary code; this is only safe as long as
                # every publisher connected to the dialog system is trusted.
                timestamp, content = pickle.loads(msg[1])
                if self.debug_logger:
                    self.debug_logger.info(
                        f"- (DS): listener thread for function {func_instance}:\n   received for topic {topic}:\n   {content}")

                # simple synchronization mechanism: remember only newest values,
                # store them until there was at least 1 new value received per topic.
                # Then call callback function with complete set of values.
                # Reset values afterwards and start collecting again.

                # problem: routing based on prefixes -> function argument names may differ
                # solution: find the longest subscribed topic that is a prefix of the received topic
                common_prefix = ""
                for key in all_sub_topics:
                    # compare candidate lengths (not the length of the received topic),
                    # otherwise the last matching key would win instead of the longest one
                    if topic.startswith(key) and len(key) > len(common_prefix):
                        common_prefix = key
                if not common_prefix:
                    # message belongs to none of this function's topics -> nothing to store
                    continue

                if common_prefix in latest_only_topics:
                    # store only latest value
                    values[common_prefix] = content  # set value for received topic
                    timestamps[common_prefix] = timestamp  # set timestamp for received value
                else:
                    # topic is a queued_topic - queue all values and their timestamps
                    if common_prefix not in values:
                        values[common_prefix] = []
                        timestamps[common_prefix] = []
                    values[common_prefix].append(content)
                    timestamps[common_prefix].append(timestamp)

                if len(values) == num_topics:
                    # received a new value for each topic -> call callback function
                    if func_instance.timestamp_enabled:
                        # append timestamps, if required
                        values['timestamps'] = timestamps
                    if self.debug_logger:
                        self.debug_logger.info(
                            f"- (DS): received all messages for function {func_instance}\n   -> CALLING function")
                    if self.__class__ == Service:
                        # NOTE workaround for publisher / subscriber without being an instance method
                        func_instance(**values)
                    else:
                        func_instance(self, **values)
                    # reset values
                    values = {}
                    timestamps = {}
        except KeyboardInterrupt:
            break
        except Exception:
            # keep the listener alive on errors raised while handling a single message
            print("THREAD ERROR")
            import traceback
            traceback.print_exc()
    # shutdown
    subscriber.close()
__init__(self, domain='', sub_topic_domains={}, pub_topic_domains={}, ds_host_addr='127.0.0.1', sub_port=65533, pub_port=65534, protocol='tcp', debug_logger=None, identifier=None)
special
¶
Create a new service instance (call this super constructor from your inheriting classes!).
Parameters:
Name | Type | Description | Default |
---|---|---|---|
domain |
Union[str, Domain] |
The domain(-name) of your service (or empty string, if domain-agnostic). If a domain(-name) is set, it will automatically filter out all messages from other domains. If no domain(-name) is set, messages from all domains will be received. |
'' |
sub_topic_domains |
Dict[str, str] |
change subscribed to topics to listen to a specific domain (e.g. 'erase'/append a domain for a specific topic) |
{} |
pub_topic_domains |
Dict[str, str] |
change published topics to a specific domain (e.g. 'erase'/append a domain for a specific topic) |
{} |
ds_host_addr |
str |
IP-address of the parent `DialogSystem` (default: localhost) |
'127.0.0.1' |
sub_port |
int |
subscriber port following zmq's XSUB/XPUB pattern |
65533 |
pub_port |
int |
publisher port following zmq's XSUB/XPUB pattern |
65534 |
protocol |
string |
communication protocol with `DialogSystem` - has to match! Possible options: `tcp`, `inproc`, `ipc` |
'tcp' |
debug_logger |
DiasysLogger |
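If not `None`, all messages are printed to the logger, including send/receive events. Can be useful for debugging because you can still see messages received by the `DialogSystem` even if they are never forwarded (as expected) to your `Service`. |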
None |
identifier |
str |
Set this to a UNIQUE identifier per service to be run remotely.
See `RemoteService` for more details. |
None |
Source code in adviser/services/service.py
def __init__(self, domain: Union[str, Domain] = "", sub_topic_domains: Dict[str, str] = None,
             pub_topic_domains: Dict[str, str] = None, ds_host_addr: str = "127.0.0.1", sub_port: int = 65533,
             pub_port: int = 65534, protocol: str = "tcp", debug_logger: DiasysLogger = None,
             identifier: str = None):
    """
    Create a new service instance *(call this super constructor from your inheriting classes!)*.

    Args:
        domain (Union[str, Domain]): The domain(-name) of your service (or empty string, if domain-agnostic).
                                     If a domain(-name) is set, it will automatically filter out all messages from other domains.
                                     If no domain(-name) is set, messages from all domains will be received.
        sub_topic_domains (Dict[str, str]): change subscribed to topics to listen to a specific domain
                                            (e.g. 'erase'/append a domain for a specific topic).
                                            `None` (default) is treated as an empty mapping.
        pub_topic_domains (Dict[str, str]): change published topics to a specific domain
                                            (e.g. 'erase'/append a domain for a specific topic).
                                            `None` (default) is treated as an empty mapping.
        ds_host_addr (str): IP-address of the parent `DialogSystem` (default: localhost)
        sub_port (int): subscriber port following zmq's XSUB/XPUB pattern
        pub_port (int): publisher port following zmq's XSUB/XPUB pattern
        protocol (string): communication protocol with `DialogSystem` - has to match!
                           Possible options: `tcp`, `inproc`, `ipc`
        debug_logger (DiasysLogger): If not `None`, all messages are printed to the logger, including send/receive events.
                                     Can be useful for debugging because you can still see messages received by the `DialogSystem`
                                     even if they are never forwarded (as expected) to your `Service`.
        identifier (str): Set this to a *UNIQUE* identifier per service to be run remotely.
                          See `RemoteService` for more details.
    """
    self.is_training = False
    self.domain = domain
    # get domain name (gets appended to all sub/pub topics so that different domain topics don't get shared)
    if domain is None:
        self._domain_name = ""
    elif isinstance(domain, Domain):
        self._domain_name = domain.get_domain_name()
    else:
        self._domain_name = domain

    # create a fresh dict per instance: a `{}` default in the signature would be one object
    # shared (and possibly mutated) across all services constructed without these arguments
    self._sub_topic_domains = {} if sub_topic_domains is None else sub_topic_domains
    self._pub_topic_domains = {} if pub_topic_domains is None else pub_topic_domains

    # socket information
    self._host_addr = ds_host_addr
    self._sub_port = sub_port
    self._pub_port = pub_port
    self._protocol = protocol
    self._identifier = identifier

    self.debug_logger = debug_logger

    self._sub_topics = set()
    self._pub_topics = set()
    self._publish_sockets = dict()

    self._internal_start_topics = dict()
    self._internal_end_topics = dict()
    self._internal_terminate_topics = dict()

    # NOTE: class name + memory pointer make topic unique (required, e.g. for running multiple instances of same module!)
    topic_prefix = f"{type(self).__name__}/{id(self)}"
    self._start_topic = f"{topic_prefix}/START"
    self._end_topic = f"{topic_prefix}/END"
    self._terminate_topic = f"{topic_prefix}/TERMINATE"
    self._train_topic = f"{topic_prefix}/TRAIN"
    self._eval_topic = f"{topic_prefix}/EVAL"
dialog_end(self)
¶
This function is called after a dialog ended (Topics.DIALOG_END message was received). You should overwrite this function to record dialog-level information.
dialog_exit(self)
¶
This function is called when the dialog system is shutting down. You should overwrite this function to stop your threads and cleanup any open resources.
dialog_start(self)
¶
This function is called before the first message to a new dialog is published. You should overwrite this function to set/reset dialog-level variables.
eval(self)
¶
get_all_published_topics(self)
¶
get_all_subscribed_topics(self)
¶
run_standalone(self, host_reg_port=65535)
¶
Run this service as a standalone service (without a DialogSystem
) on a remote node.
Use a RemoteService
with corresponding identifier on the DialogSystem
node to connect both.
Note: this call is blocking!
Parameters:
Name | Type | Description | Default |
---|---|---|---|
host_reg_port |
int |
The port on the `DialogSystem` node listening for `Service` register requests |
65535 |
Source code in adviser/services/service.py
def run_standalone(self, host_reg_port: int = 65535):
    """
    Run this service as a standalone service (without a `DialogSystem`) on a remote node.
    Use a `RemoteService` with *corresponding identifier* on the `DialogSystem` node to connect both.

    Note: this call is blocking!

    Args:
        host_reg_port (int): The port on the `DialogSystem` node listening for `Service` register requests
    """
    assert self._identifier is not None, "running a service on a remote node requires a unique identifier"
    print("Waiting for dialog system host...")

    # announce this service (domain, topics and control topics) to the dialog system node
    self._init_pubsub()
    sync_endpoint = Context.instance().socket(zmq.REQ)
    sync_endpoint.connect(f"tcp://{self._host_addr}:{host_reg_port}")
    service_info = (self._domain_name, self._sub_topics, self._pub_topics, self._start_topic, self._end_topic,
                    self._terminate_topic)
    payload = pickle.dumps(service_info)
    register_frame = bytes(f"REGISTER_{self._identifier}", encoding="ascii")
    sync_endpoint.send_multipart((register_frame, payload))

    # block until the host acknowledges the registration of exactly this identifier
    ack_prefix = "ACK_REGISTER_"
    while True:
        reply = sync_endpoint.recv().decode("utf-8")
        if not reply.startswith(ack_prefix):
            continue
        if reply[len(ack_prefix):] != self._identifier:
            continue
        self._register_with_dialogsystem()
        confirm_frame = bytes(f"CONF_REGISTER_{self._identifier}", encoding="ascii")
        sync_endpoint.send_multipart((confirm_frame, pickle.dumps(True)))
        break
    print("Done")
train(self)
¶
PublishSubscribe(sub_topics=[], pub_topics=[], queued_sub_topics=[])
¶
Decorator function for services. To be able to publish / subscribe to / from topics, your class is required to inherit from services.service.Service. Then, decorate any function you like.
Your function will be called as soon as: * at least one message is received for each topic in sub_topics (only latest message will be forwarded, others dropped) * at least one message is received for each topic in queued_sub_topics (all messages since the previous function call will be forwarded as a list)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
sub_topics(List[str |
or utils.topics.Topic] |
The topics you want to get the latest messages from. If multiple messages are received until your function is called, you will only receive the value of the latest message, previously received values will be discarded. |
required |
pub_topics(List[str |
or utils.topics.Topic] |
The topics you want to publish messages to. |
required |
queued_sub_topics(List[str |
or utils.topics.Topic] |
The topics you want to get all messages from. If multiple messages are received until your function is called, you will receive all values since the previous function call as a list. |
required |
Notes
- Subscription topic names have to match your function keywords
- Your function should return a dictionary with the keys matching your publish topics names and the value being any arbitrary python object or primitive type you want to send
- sub_topics and queued_sub_topics have to be disjoint!
- If you need timestamps for your messages, specify a 'timestamps' argument in your subscribing function. It will be filled by a dictionary providing timestamps for each received value, indexed by name.
Technical notes: * Data will be automatically pickled / unpickled during send / receive to reduce message size. However, some python objects are not serializable (e.g. database connections) for good reasons and will throw an error if you try to publish them. * The domain name of your service class will be appended to your publish topics. Subscription topics are prefix-matched, so you will receive all messages from 'topic/suffix' if you subscribe to 'topic'.
Source code in adviser/services/service.py
def PublishSubscribe(sub_topics: List[str] = [], pub_topics: List[str] = [], queued_sub_topics: List[str] = []):
    """
    Decorator function for services.
    To be able to publish / subscribe to / from topics,
    your class is required to inherit from services.service.Service.
    Then, decorate any function you like.

    Your function will be called as soon as:
        * at least one message is received for each topic in sub_topics (only latest message will be forwarded, others dropped)
        * at least one message is received for each topic in queued_sub_topics (all messages since the previous function call will be forwarded as a list)

    Args:
        sub_topics(List[str or utils.topics.Topic]): The topics you want to get the latest messages from.
                                                     If multiple messages are received until your function is called,
                                                     you will only receive the value of the latest message, previously received
                                                     values will be discarded.
        pub_topics(List[str or utils.topics.Topic]): The topics you want to publish messages to.
        queued_sub_topics(List[str or utils.topics.Topic]): The topics you want to get all messages from.
                                                            If multiple messages are received until your function is called,
                                                            you will receive all values since the previous function call as a list.

    Returns:
        A decorator. The decorated function keeps returning the result dictionary of the wrapped function
        (with any "/domain" suffix stripped from its keys) and additionally carries the attributes
        `pubsub`, `sub_topics`, `queued_sub_topics`, `pub_topics` and `timestamp_enabled`.

    Notes:
        * Subscription topic names have to match your function keywords
        * Your function should return a dictionary with the keys matching your publish topics names
          and the value being any arbitrary python object or primitive type you want to send
        * sub_topics and queued_sub_topics have to be disjoint!
        * If you need timestamps for your messages, specify a 'timestamps' argument in your subscribing function.
          It will be filled by a dictionary providing timestamps for each received value, indexed by name.

    Technical notes:
        * Data will be automatically pickled / unpickled during send / receive to reduce message size.
          However, some python objects are not serializable (e.g. database connections) for good reasons
          and will throw an error if you try to publish them.
        * The domain name of your service class will be appended to your publish topics.
          Subscription topics are prefix-matched, so you will receive all messages from 'topic/suffix'
          if you subscribe to 'topic'.
    """

    def wrapper(func):
        def delegate(self, *args, **kwargs):
            func_inst = getattr(self, func.__name__)

            callargs = list(args)
            if self in callargs:  # remove self when in *args, because already known to function
                callargs.remove(self)

            result = func(self, *callargs, **kwargs)
            if result:
                # split returned keys of the form "topic/domain" into the bare topic and its explicit domain
                # TODO: a key containing several "/" characters only keeps its first two segments
                domains = {res.split("/")[0]: res.split("/")[1] if "/" in res else "" for res in result}
                result = {key.split("/")[0]: result[key] for key in result}

            if func_inst not in self._publish_sockets:
                # not a publisher, just normal function
                return result

            socket = self._publish_sockets[func_inst]
            if socket and result:
                # publish messages
                for topic in pub_topics:
                    if topic not in result:
                        continue
                    # The service-wide domain takes precedence; otherwise use the domain given for
                    # THIS topic only (resolved per topic so that it cannot leak into later topics).
                    topic_domain = self._domain_name if self._domain_name else domains[topic]
                    topic_domain_str = f"{topic}/{topic_domain}" if topic_domain else topic
                    if topic in self._pub_topic_domains:
                        # explicit per-topic override configured on the service
                        override = self._pub_topic_domains[topic]
                        topic_domain_str = f"{topic}/{override}" if override else topic
                    _send_msg(socket, topic_domain_str, result[topic])
                    if self.debug_logger:
                        self.debug_logger.info(
                            f"- (DS): sent message from {func} to topic {topic_domain_str}:\n   {result[topic]}")
            return result

        # declare function as publish / subscribe functions and attach the respective topics
        delegate.pubsub = True
        delegate.sub_topics = sub_topics
        delegate.queued_sub_topics = queued_sub_topics
        delegate.pub_topics = pub_topics
        # check arguments: is subscriber interested in timestamps?
        delegate.timestamp_enabled = 'timestamps' in inspect.getfullargspec(func)[0]

        return delegate

    return wrapper
simulator
special
¶
This package contains the handcrafted user simulator and related services.
emotion_simulator
¶
EmotionSimulator (Service)
¶
Class which generates user emotion/engagements. Currently outputs either a user defined or random emotion/engagement level and was designed to test the affective services work correctly. However, in the future it could be extended to be more realistic.
Source code in adviser/services/simulator/emotion_simulator.py
class EmotionSimulator(Service):
    """
    Service producing a simulated user emotion and engagement level per turn.

    Depending on its configuration it emits either fixed (user defined) values or values
    drawn at random. It was designed to test that the affective services work correctly
    and could be extended to behave more realistically in the future.
    """

    def __init__(self, domain: JSONLookupDomain = None, logger: DiasysLogger = None,
                 random: bool = True, static_emotion: EmotionType = EmotionType.Neutral,
                 static_engagement: EngagementType = EngagementType.High):
        """
        Args:
            domain (JSONLookupDomain): domain passed on to the `Service` constructor
            logger (DiasysLogger): logger stored on the instance
            random (bool): if True, emotion and engagement are drawn randomly each turn;
                           otherwise the static values below are used
            static_emotion (EmotionType): emotion published when `random` is False
            static_engagement (EngagementType): engagement published when `random` is False
        """
        Service.__init__(self, domain=domain)
        self.domain = domain
        self.logger = logger
        self.random = random
        self.emotion = static_emotion
        self.engagement = static_engagement

    @PublishSubscribe(sub_topics=["user_acts"], pub_topics=["emotion", "engagement"])
    def send_emotion(self, user_acts: List[UserAct] = None) -> Dict[str, str]:
        """
        Publishes an emotion and engagement value for a turn

        Args:
            user_acts (List[UserAct]): the useracts, needed to synchronize when emotion should
                                       be generated

        Returns:
            (dict): A dictionary representing the simulated user emotion and engagement. The keys are
                    "emotion" and "engagement" where "emotion" is a dictionary which currently only contains
                    emotion category information but could be expanded to include other emotion measures, and
                    "engagement" corresponds to an EngagementType object.
        """
        if self.random:
            # draw emotion first, then engagement
            emotion = random.choice(list(EmotionType))
            engagement = random.choice(list(EngagementType))
        else:
            emotion, engagement = self.emotion, self.engagement
        return {"emotion": {"category": emotion}, "engagement": engagement}
__init__(self, domain=None, logger=None, random=True, static_emotion=<EmotionType.Neutral: 'neutral'>, static_engagement=<EngagementType.High: 'high'>)
special
¶
Source code in adviser/services/simulator/emotion_simulator.py
def __init__(self, domain: JSONLookupDomain = None, logger: DiasysLogger = None,
             random: bool = True, static_emotion: EmotionType = EmotionType.Neutral,
             static_engagement: EngagementType = EngagementType.High):
    """
    Args:
        domain (JSONLookupDomain): domain passed on to the `Service` constructor
        logger (DiasysLogger): logger stored on the instance
        random (bool): flag stored on the instance selecting random instead of static values
        static_emotion (EmotionType): emotion stored as the static value
        static_engagement (EngagementType): engagement stored as the static value
    """
    Service.__init__(self, domain=domain)
    self.domain = domain
    self.logger = logger
    self.random = random
    self.emotion = static_emotion
    self.engagement = static_engagement
send_emotion(self, *args, **kwargs)
¶
Source code in adviser/services/simulator/emotion_simulator.py
def delegate(self, *args, **kwargs):
    """
    Calls the wrapped function and, if it is registered as a publisher on this service,
    publishes the returned values to their topics.

    Returns:
        The result of the wrapped function; for a non-empty dictionary, any "/domain" suffix
        is stripped from its keys.
    """
    func_inst = getattr(self, func.__name__)

    callargs = list(args)
    if self in callargs:  # remove self when in *args, because already known to function
        callargs.remove(self)

    result = func(self, *callargs, **kwargs)
    if result:
        # split returned keys of the form "topic/domain" into the bare topic and its explicit domain
        # TODO: a key containing several "/" characters only keeps its first two segments
        domains = {res.split("/")[0]: res.split("/")[1] if "/" in res else "" for res in result}
        result = {key.split("/")[0]: result[key] for key in result}

    if func_inst not in self._publish_sockets:
        # not a publisher, just normal function
        return result

    socket = self._publish_sockets[func_inst]
    if socket and result:
        # publish messages
        for topic in pub_topics:
            if topic not in result:
                continue
            # The service-wide domain takes precedence; otherwise use the domain given for
            # THIS topic only (resolved per topic so that it cannot leak into later topics).
            topic_domain = self._domain_name if self._domain_name else domains[topic]
            topic_domain_str = f"{topic}/{topic_domain}" if topic_domain else topic
            if topic in self._pub_topic_domains:
                # explicit per-topic override configured on the service
                override = self._pub_topic_domains[topic]
                topic_domain_str = f"{topic}/{override}" if override else topic
            _send_msg(socket, topic_domain_str, result[topic])
            if self.debug_logger:
                self.debug_logger.info(
                    f"- (DS): sent message from {func} to topic {topic_domain_str}:\n   {result[topic]}")
    return result
goal
¶
This module provides the Goal class and related stuff.
Constraint
¶
Source code in adviser/services/simulator/goal.py
class Constraint(object):
    """A single (slot, value) constraint as used in a goal."""

    def __init__(self, slot, value):
        """
        The class for a constraint as used in the goal.

        Args:
            slot (str): The slot.
            value (str): The value.
        """
        self.slot = slot
        self.value = value

    def __eq__(self, other):
        """Constraint should be equal if the slot and value is the same."""
        if isinstance(other, Constraint):
            return (self.slot == other.slot
                    and self.value == other.value)
        return False

    def __getitem__(self, key):
        """
        Tuple-like access: index 0 is the slot, index 1 is the value.

        Raises:
            TypeError: if `key` is not an integer.
            IndexError: if `key` is neither 0 nor 1 (this also ends tuple unpacking / iteration).
        """
        if not isinstance(key, int):
            raise TypeError(f"Constraint indices must be integers, not {type(key).__name__}")
        if key == 0:
            return self.slot
        if key == 1:
            return self.value
        raise IndexError("Constraint index out of range (0 = slot, 1 = value)")

    def __repr__(self):
        return "Constraint(slot={}, value={})".format(self.slot, self.value)

    def __hash__(self):
        # Hash the pair as a tuple: multiplying the two hashes collapses to 0 as soon as one
        # part hashes to 0 and cannot tell (a, b) from (b, a). Equal constraints still hash equal.
        return hash((self.slot, self.value))
__eq__(self, other)
special
¶
Constraint should be equal if the slot and value is the same.
__getitem__(self, key)
special
¶
__hash__(self)
special
¶
__init__(self, slot, value)
special
¶
The class for a constraint as used in the goal.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
slot |
str |
The slot. |
required |
value |
str |
The value. |
required |
__repr__(self)
special
¶
Goal
¶
Source code in adviser/services/simulator/goal.py
class Goal(object):
def __init__(self, domain: JSONLookupDomain, parameters=None):
"""
The class representing a goal, therefore containing requests and constraints.
Args:
domain (JSONLookupDomain): The domain for which the goal will be instantiated.
It will only work within this domain.
parameters (dict): The parameters for the goal defined by a key=value mapping: 'MinVenues'
(int) allows to set a minimum number of venues which fulfill the constraints of the goal,
'MinConstraints' (int) and 'MaxConstraints' (int) set the minimum and maximum amount of
constraints respectively, 'MinRequests' (int) and 'MaxRequests' (int) set the minimum and
maximum amount of requests respectively and 'Reachable' (float) allows to specify how many
(in percent) of all generated goals are definitely fulfillable (i.e. there exists a venue
for the current goal) or not (doesn't have to be fulfillable). Although the parameter
'Reachable' equals 1.0 implicitly states that 'MinVenues' equals 1 or more, the
implementation looks different, is more efficient and takes all goals into consideration
(since 'Reachable' is a float (percentage of generated goals)). On the other hand, setting
'MinVenues' to any number bigger than 0 forces every goal to be fulfillable.
"""
self.domain = domain
self.parameters = parameters or {}
# cache inform and request slots
# make sure to copy the list (shallow is sufficient)
self.inf_slots = sorted(list(domain.get_informable_slots())[:])
# make sure that primary key is never a constraint
if self.domain.get_primary_key() in self.inf_slots:
self.inf_slots.remove(self.domain.get_primary_key())
# TODO sometimes ask for specific primary key with very small probability (instead of any other constraints?) # pylint: disable=line-too-long
self.inf_slot_values = {}
for slot in self.inf_slots:
self.inf_slot_values[slot] = sorted(
domain.get_possible_values(slot)[:])
self.req_slots = sorted(domain.get_requestable_slots()[:])
# self.req_slots_without_informables = sorted(list(
# set(self.req_slots).difference(self.inf_slots)))
# make sure that primary key is never a request as it is added anyway
if self.domain.get_primary_key() in self.req_slots:
self.req_slots.remove(self.domain.get_primary_key())
self.constraints = []
self.requests = {}
self.excluded_inf_slot_values = {}
self.missing_informs = []
def init(self, random_goal=True, constraints=None, requests=None) -> None:
"""
Initializes a goal randomly OR using the given constraints and requests.
Args:
random_goal (bool): If True, a goal will be drawn randomly from available constraints
and requests (considering the parameters given in the constructor, if any). However if
constraints and requests are given and both don't equal None, this parameter is
considered as False. If False, the given constraints and requests are used.
constraints (List[Constraint]): The constraints which will be used for the goal.
requests (Dict[str, Union[None,str]]): The requests which will be used for the goal.
"""
# reset goal
self.constraints = []
self.requests = {}
self.excluded_inf_slot_values = {key: set()
for key in self.inf_slot_values}
# TODO implement possibility to pass either constraints or requests as a parameter
if random_goal and constraints is None and requests is None:
self._init_random_goal()
else:
self._init_from_parameters(constraints, requests)
# make sure that primary key is always requested
self.requests[self.domain.get_primary_key()] = None
self.missing_informs = [UserAct(act_type=UserActionType.Inform, slot=_constraint.slot, value=_constraint.value)
for _constraint in self.constraints]
def _init_random_goal(self):
"""Randomly sets the constraints and requests for the goal."""
num_venues = -1
# check that there exist at least self.parameters['MinVenues'] venues for this goal
if 'MinVenues' in self.parameters:
min_venues = self.parameters['MinVenues']
else:
min_venues = 0 # default is to not have any lower bound
while num_venues < min_venues:
# TODO exclude 'dontcare' from goal
# TODO make sure that minconstraints and minrequests are a valid number for the current domain # pylint: disable=line-too-long
if 'MaxConstraints' in self.parameters:
num_constraints_max = min(len(self.inf_slots), int(
self.parameters['MaxConstraints']))
else:
# NOTE could become pretty high
num_constraints_max = len(self.inf_slots)
if 'MinConstraints' in self.parameters:
num_constraints_min = int(self.parameters['MinConstraints'])
else:
num_constraints_min = 1
# draw constraints uniformly
num_constraints = common.random.randint(
num_constraints_min, num_constraints_max)
constraint_slots = common.numpy.random.choice(
self.inf_slots, num_constraints, replace=False)
self.constraints = []
if ('Reachable' in self.parameters
and common.random.random() < self.parameters['Reachable']):
# pick entity from database and set constraints
results = self.domain.find_entities(
constraints={}, requested_slots=constraint_slots.tolist())
assert results, "Cannot receive entity from database,\
probably because the database is empty."
entity = common.random.choice(results)
for constraint in constraint_slots:
self.constraints.append(Constraint(
constraint, entity[constraint]))
else:
# pick random constraints
for constraint in constraint_slots:
self.constraints.append(Constraint(constraint, common.numpy.random.choice(
self.inf_slot_values[constraint], size=1)[0]))
# check if there are enough venues for the current goal
num_venues = len(self.domain.find_entities(constraints={
constraint.slot: constraint.value for constraint in self.constraints}))
possible_req_slots = sorted(
list(set(self.req_slots).difference(constraint_slots)))
if 'MaxRequests' in self.parameters:
num_requests_max = min(len(possible_req_slots), int(
self.parameters['MaxRequests']))
else:
# NOTE could become pretty high
num_requests_max = len(possible_req_slots)
if 'MinRequests' in self.parameters:
num_requests_min = int(self.parameters['MinRequests'])
else:
num_requests_min = 0 # primary key is included anyway
num_requests = common.random.randint(
num_requests_min, num_requests_max)
self.requests = {slot: None for slot in common.numpy.random.choice(
possible_req_slots, num_requests, replace=False)}
# print(self.requests)
# print(self.constraints)
# TODO add some remaining informable slot as request with some probability
# add_req_slots_candidates = list(set(self.inf_slots).difference(constraint_slots))
def _init_from_parameters(self, constraints, requests):
"""Converts the given constraints and requests to the goal."""
# initialise goal with given constraints and requests
if isinstance(constraints, list):
if constraints:
if isinstance(constraints[0], Constraint):
self.constraints = copy.deepcopy(constraints)
if isinstance(constraints[0], tuple):
# assume tuples in list with strings (slot, value)
self.constraints = [Constraint(
slot, value) for slot, value in constraints]
elif isinstance(constraints, dict):
self.constraints = [Constraint(slot, value)
for slot, value in constraints.items()]
else:
raise ValueError(
"Given constraints for goal must be of type list or dict.")
if not isinstance(requests, dict):
if isinstance(requests, list):
# assume list of strings
self.requests = dict.fromkeys(requests, None)
else:
self.requests = requests
num_venues = len(self.domain.find_entities(constraints={
constraint.slot: constraint.value for constraint in self.constraints}))
if 'MinVenues' in self.parameters:
assert num_venues >= self.parameters['MinVenues'], "There are not enough venues for\
the given constraints in the database. Either change constraints or lower\
parameter Goal:MinVenues."
def reset(self):
"""Resets all requests of the goal."""
# reset goal -> empty all requests
self.requests = dict.fromkeys(self.requests)
# for slot in self.requests:
# self.requests[slot] = None
def __repr__(self):
return "Goal(constraints={}, requests={})".format(self.constraints, self.requests)
# NOTE only checks if requests are fulfilled,
# not whether the result from the system conforms to the constraints
def is_fulfilled(self):
"""
Checks whether all requests have been fulfilled.
Returns:
bool: ``True`` if all requests have been fulfilled, ``False`` otherwise.
.. note:: Does not check whether the venue (issued by the system) fulfills the constraints
since it's the system's task to give an appropriate venue by requesting the user's
constraints.
"""
for slot, value in self.requests.items():
assert slot != self.domain.get_primary_key() or value != 'none' # TODO remove later
if value is None:
return False
return True
def fulfill_request(self, slot, value):
"""
Fulfills a request, i.e. sets ``value`` for request ``slot``.
Args:
slot (str): The request slot which will be filled.
value (str): The value the request slot will be filled with.
"""
if slot in self.requests:
self.requests[slot] = value
# does not consider 'dontcare's
# NOTE better use is_inconsistent_constraint or is_inconsistent_constraint_strict
# def contains_constraint(self, constraint):
# if constraint in self.constraints:
# return True
# return False
# constraint is consistent with goal if values match or value in goal is 'dontcare'
def is_inconsistent_constraint(self, constraint):
"""
Checks whether the given constraint is consistent with the goal. A constraint is also
consistent if it's value is 'dontcare' in the current goal.
Args:
constraint (Constraint): The constraint which will be checked for consistency.
Returns:
bool: True if values match or value in goal is 'dontcare', False otherwise.
"""
for _constraint in self.constraints:
if _constraint.slot == constraint.slot and (_constraint.value != constraint.value \
and _constraint.value != 'dontcare'):
return True
return False
# constraint is consistent with goal if values match
# ('dontcare' is considered as different value)
def is_inconsistent_constraint_strict(self, constraint):
    """
    Checks whether the given constraint is strictly inconsistent with the goal, i.e. a
    'dontcare' in the goal does not match other values.

    Args:
        constraint (Constraint): The constraint which will be checked for consistency.

    Returns:
        bool: False if the goal contains exactly this slot/value pair or if the given value
        is 'dontcare'; True otherwise.

    !!! seealso "See Also"
        [`is_inconsistent_constraint`][adviser.services.simulator.goal.Goal.is_inconsistent_constraint]
    """
    explicitly_listed = any(
        goal_constraint.slot == constraint.slot and goal_constraint.value == constraint.value
        for goal_constraint in self.constraints)
    if explicitly_listed:
        return False
    # No exact match: a slot which is not listed in the goal is implicitly 'dontcare', so only
    # a given value other than 'dontcare' is reported as inconsistent.
    # NOTE(review): a given 'dontcare' is also accepted when the goal lists the slot with a
    # concrete value - confirm that this is intended.
    return constraint.value != 'dontcare'
def get_constraint(self, slot):
    """
    Gets the value for a given constraint ``slot``.

    Args:
        slot (str): The constraint ``slot`` which will be looked up.

    Returns:
        str: The ``value`` of the first constraint for ``slot``, or 'dontcare' if the goal
        does not contain a constraint for ``slot``.
    """
    return next(
        (goal_constraint.value for goal_constraint in self.constraints
         if goal_constraint.slot == slot),
        'dontcare')
# update the constraint with the slot 'slot' with 'value', assuming the constraints are unique
def update_constraint(self, slot, value):
    """
    Replaces the value of the constraint ``slot`` with ``value``.

    Only the first constraint with a matching slot is changed (constraints are assumed to be
    unique per slot).

    Args:
        slot (str): The constraint *slot* which will be updated.
        value (str): The *value* with which the constraint will be updated.

    Returns:
        bool: ``True`` if a constraint for ``slot`` exists in the goal and was updated,
        ``False`` otherwise.
    """
    target = next(
        (candidate for candidate in self.constraints if candidate.slot == slot), None)
    if target is None:
        return False
    target.value = value
    return True
__init__(self, domain, parameters=None)
special
¶
The class representing a goal, therefore containing requests and constraints.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
domain |
JSONLookupDomain |
The domain for which the goal will be instantiated. It will only work within this domain. |
required |
parameters |
dict |
The parameters for the goal defined by a key=value mapping: 'MinVenues' (int) allows to set a minimum number of venues which fulfill the constraints of the goal, 'MinConstraints' (int) and 'MaxConstraints' (int) set the minimum and maximum amount of constraints respectively, 'MinRequests' (int) and 'MaxRequests' (int) set the minimum and maximum amount of requests respectively and 'Reachable' (float) allows to specify how many (in percent) of all generated goals are definitely fulfillable (i.e. there exists a venue for the current goal) or not (doesn't have to be fulfillable). Although the parameter 'Reachable' equals 1.0 implicitly states that 'MinVenues' equals 1 or more, the implementation looks different, is more efficient and takes all goals into consideration (since 'Reachable' is a float (percentage of generated goals)). On the other hand, setting 'MinVenues' to any number bigger than 0 forces every goal to be fulfillable. |
None |
Source code in adviser/services/simulator/goal.py
def __init__(self, domain: JSONLookupDomain, parameters=None):
    """
    The class representing a goal, therefore containing requests and constraints.

    Args:
        domain (JSONLookupDomain): The domain for which the goal will be instantiated.
            It will only work within this domain.
        parameters (dict): The parameters for the goal defined by a key=value mapping:
            'MinVenues' (int) sets a minimum number of venues which fulfill the constraints
            of the goal, 'MinConstraints' (int) and 'MaxConstraints' (int) set the minimum and
            maximum amount of constraints respectively, 'MinRequests' (int) and 'MaxRequests'
            (int) set the minimum and maximum amount of requests respectively and 'Reachable'
            (float) specifies how many (in percent) of all generated goals are definitely
            fulfillable (i.e. there exists a venue for the current goal). Although 'Reachable'
            equal to 1.0 implicitly states that 'MinVenues' equals 1 or more, the
            implementation differs, is more efficient and takes all goals into consideration
            (since 'Reachable' is a percentage of generated goals). Setting 'MinVenues' to
            any number bigger than 0, on the other hand, forces every goal to be fulfillable.
    """
    self.domain = domain
    self.parameters = parameters or {}
    primary_key = self.domain.get_primary_key()

    # cache inform and request slots; sorted() always builds a new list, so the collections
    # handed out by the domain are never modified
    self.inf_slots = sorted(domain.get_informable_slots())
    # make sure that primary key is never a constraint
    if primary_key in self.inf_slots:
        self.inf_slots.remove(primary_key)
    # TODO sometimes ask for specific primary key with very small probability (instead of any other constraints?) # pylint: disable=line-too-long
    self.inf_slot_values = {
        slot: sorted(domain.get_possible_values(slot)) for slot in self.inf_slots}

    self.req_slots = sorted(domain.get_requestable_slots())
    # make sure that primary key is never a request as it is added anyway
    if primary_key in self.req_slots:
        self.req_slots.remove(primary_key)

    self.constraints = []
    self.requests = {}
    self.excluded_inf_slot_values = {}
    self.missing_informs = []
__repr__(self)
special
¶
fulfill_request(self, slot, value)
¶
Fulfills a request, i.e. sets `value` for request `slot`.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
slot |
str |
The request slot which will be filled. |
required |
value |
str |
The value the request slot will be filled with. |
required |
Source code in adviser/services/simulator/goal.py
get_constraint(self, slot)
¶
Gets the value for a given constraint `slot`.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
slot |
str |
The constraint `slot` which will be looked up. |
required |
Returns:
Type | Description |
---|---|
str |
The constraint `value`, or 'dontcare' if the goal contains no constraint for the slot. |
Source code in adviser/services/simulator/goal.py
def get_constraint(self, slot):
    """
    Gets the value for a given constraint ``slot``.

    Args:
        slot (str): The constraint ``slot`` which will be looked up.

    Returns:
        str: The ``value`` of the first constraint for ``slot``, or 'dontcare' if the goal
        does not contain a constraint for ``slot``.
    """
    return next(
        (goal_constraint.value for goal_constraint in self.constraints
         if goal_constraint.slot == slot),
        'dontcare')
init(self, random_goal=True, constraints=None, requests=None)
¶
Initializes a goal randomly OR using the given constraints and requests.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
random_goal |
bool |
If True, a goal will be drawn randomly from available constraints and requests (considering the parameters given in the constructor, if any). However if constraints and requests are given and both don't equal None, this parameter is considered as False. If False, the given constraints and requests are used. |
True |
constraints |
List[Constraint] |
The constraints which will be used for the goal. |
None |
requests |
Dict[str, Union[None,str]] |
The requests which will be used for the goal. |
None |
Source code in adviser/services/simulator/goal.py
def init(self, random_goal=True, constraints=None, requests=None) -> None:
    """
    Resets the goal and initializes it randomly OR from the given constraints and requests.

    Args:
        random_goal (bool): If True and neither ``constraints`` nor ``requests`` is given, a
            goal is drawn randomly (considering the parameters given in the constructor, if
            any). As soon as ``constraints`` or ``requests`` is not None, this flag is treated
            as False and the given values are used.
        constraints (List[Constraint]): The constraints which will be used for the goal.
        requests (Dict[str, Union[None,str]]): The requests which will be used for the goal.
    """
    # reset goal
    self.constraints = []
    self.requests = {}
    self.excluded_inf_slot_values = dict(
        (key, set()) for key in self.inf_slot_values)

    # TODO implement possibility to pass either constraints or requests as a parameter
    use_given_values = not random_goal or constraints is not None or requests is not None
    if use_given_values:
        self._init_from_parameters(constraints, requests)
    else:
        self._init_random_goal()

    # make sure that primary key is always requested
    primary_key = self.domain.get_primary_key()
    self.requests[primary_key] = None

    # every constraint of the goal is recorded as an inform act
    pending_informs = []
    for goal_constraint in self.constraints:
        pending_informs.append(UserAct(
            act_type=UserActionType.Inform,
            slot=goal_constraint.slot,
            value=goal_constraint.value))
    self.missing_informs = pending_informs
is_fulfilled(self)
¶
Checks whether all requests have been fulfilled.
Returns:
Type | Description |
---|---|
bool |
`True` if all requests have been fulfilled, `False` otherwise. |
.. note:: Does not check whether the venue (issued by the system) fulfills the constraints since it's the system's task to give an appropriate venue by requesting the user's constraints.
Source code in adviser/services/simulator/goal.py
def is_fulfilled(self):
    """
    Tells whether every request of the goal has received a value.

    Returns:
        bool: ``True`` if no request is left unanswered (``None``), ``False`` otherwise.

    .. note:: Only the requests are inspected. Whether the venue offered by the system
        satisfies the constraints is not verified, since offering an appropriate venue is
        the system's task.
    """
    for requested_slot, answer in self.requests.items():
        if requested_slot == self.domain.get_primary_key():
            # sanity check: the primary key must never be answered with 'none'
            assert answer != 'none'  # TODO remove later
        if answer is None:
            return False
    return True
is_inconsistent_constraint(self, constraint)
¶
Checks whether the given constraint is inconsistent with the goal. A constraint is not considered inconsistent if the value of its slot is 'dontcare' in the current goal.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
constraint |
Constraint |
The constraint which will be checked for consistency. |
required |
Returns:
Type | Description |
---|---|
bool |
True if the constraint contradicts the goal; False if the values match, the value in the goal is 'dontcare' or the goal has no constraint for the slot. |
Source code in adviser/services/simulator/goal.py
def is_inconsistent_constraint(self, constraint):
    """
    Checks whether the given constraint is inconsistent with the goal.

    A constraint contradicts the goal if the goal contains a constraint for the same slot
    whose value is neither equal to the given value nor 'dontcare'. Slots which are not
    constrained by the goal never cause an inconsistency.

    Args:
        constraint (Constraint): The constraint which will be checked for consistency.

    Returns:
        bool: True if the constraint contradicts the goal; False if the values match, the
        value in the goal is 'dontcare' or the goal has no constraint for the slot.
    """
    return any(
        goal_constraint.slot == constraint.slot
        and goal_constraint.value != constraint.value
        and goal_constraint.value != 'dontcare'
        for goal_constraint in self.constraints)
is_inconsistent_constraint_strict(self, constraint)
¶
Checks whether the given constraint is strictly inconsistent with the goal, whereby a 'dontcare' in the goal is treated as a different value (no match).
Parameters:
Name | Type | Description | Default |
---|---|---|---|
constraint |
Constraint |
The constraint which will be checked for consistency. |
required |
Returns:
Type | Description |
---|---|
bool |
False if the goal contains exactly this slot/value pair or the given value is 'dontcare', True otherwise. |
See Also
Source code in adviser/services/simulator/goal.py
def is_inconsistent_constraint_strict(self, constraint):
    """
    Checks whether the given constraint is strictly inconsistent with the goal, i.e. a
    'dontcare' in the goal does not match other values.

    Args:
        constraint (Constraint): The constraint which will be checked for consistency.

    Returns:
        bool: False if the goal contains exactly this slot/value pair or if the given value
        is 'dontcare'; True otherwise.

    !!! seealso "See Also"
        [`is_inconsistent_constraint`][adviser.services.simulator.goal.Goal.is_inconsistent_constraint]
    """
    explicitly_listed = any(
        goal_constraint.slot == constraint.slot and goal_constraint.value == constraint.value
        for goal_constraint in self.constraints)
    if explicitly_listed:
        return False
    # No exact match: a slot which is not listed in the goal is implicitly 'dontcare', so only
    # a given value other than 'dontcare' is reported as inconsistent.
    # NOTE(review): a given 'dontcare' is also accepted when the goal lists the slot with a
    # concrete value - confirm that this is intended.
    return constraint.value != 'dontcare'
reset(self)
¶
update_constraint(self, slot, value)
¶
Update a given constraint `slot` with `value`.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
slot |
str |
The constraint slot which will be updated. |
required |
value |
str |
The value with which the constraint will be updated. |
required |
Returns:
Type | Description |
---|---|
bool |
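`True` if update was successful, i.e. the constraint `slot` is included in the goal, `False` otherwise. |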
Source code in adviser/services/simulator/goal.py
def update_constraint(self, slot, value):
    """
    Replaces the value of the constraint ``slot`` with ``value``.

    Only the first constraint with a matching slot is changed (constraints are assumed to be
    unique per slot).

    Args:
        slot (str): The constraint *slot* which will be updated.
        value (str): The *value* with which the constraint will be updated.

    Returns:
        bool: ``True`` if a constraint for ``slot`` exists in the goal and was updated,
        ``False`` otherwise.
    """
    target = next(
        (candidate for candidate in self.constraints if candidate.slot == slot), None)
    if target is None:
        return False
    target.value = value
    return True
simulator
¶
This module provides the agenda-based user model for the handcrafted simulator.
Agenda
¶
A stack-like object representing an agenda. Actions can be pushed on and popped off the agenda.
Source code in adviser/services/simulator/simulator.py
class Agenda(object):
    """
    A stack-like object representing an agenda. Actions can be pushed on and popped off the agenda.

    The actions are kept in the list ``stack``; its last element is the top of the agenda.
    """

    def __init__(self):
        self.stack = []

    def __iter__(self):
        return iter(self.stack)

    def __contains__(self, value):
        return value in self.stack

    def __len__(self):
        return len(self.stack)

    def __bool__(self):
        return bool(self.stack)

    def __repr__(self):
        return repr(self.stack)

    def __str__(self):
        return str(self.stack)

    def init(self, goal):
        """
        Initializes the agenda given a goal. The agenda is emptied and then filled with inform
        actions for the constraints in the goal.

        Args:
            goal (Goal): The goal for which the agenda will be initialized.
        """
        self.stack.clear()
        # populate agenda according to goal
        # NOTE don't push bye action here since bye action could be popped with another
        # (missing) request, but user should not end dialog before having the goal fulfilled
        # NOTE do not add requests to agenda since system can't handle inform and request
        # action in same turn currently!
        # self.fill_with_requests(goal)
        self.fill_with_constraints(goal)

    def push(self, item):
        """Pushes *item* onto the agenda.

        Args:
            item: The action which will be pushed onto the agenda. If a list is given, its
                elements are pushed in order, i.e. the last element ends up on top.
        """
        if isinstance(item, list):
            self.stack += item
        else:
            self.stack.append(item)

    def get_actions(self, num_actions: int):
        """Retrieves *num_actions* actions from the agenda.

        Args:
            num_actions (int): Amount of actions which will be retrieved from the agenda. A
                negative amount or an amount exceeding the size of the agenda retrieves all
                actions.

        Returns:
            (List[UserAct]): list of *num_actions* user actions, topmost action first.
        """
        if num_actions < 0 or num_actions > len(self.stack):
            num_actions = len(self.stack)
        return [self.stack.pop() for _ in range(num_actions)]

    def clean(self, goal: Goal):
        """Cleans the agenda, i.e. makes sure that actions are consistent with goal and in the
        correct order.

        Duplicated actions (only the most recent occurrence is kept), request actions whose
        slot has already been answered in the goal and inform actions contradicting the goal
        are removed.

        Args:
            goal (Goal): The goal which is needed to determine the consistent actions.
        """
        kept = []  # collected from the top of the agenda downwards
        # reverse order since most recent actions are on top of agenda
        for action in reversed(self.stack):
            if action in kept:
                continue
            # NOTE sufficient if there is only one slot per (request) action
            # remove accomplished requests
            if (action.type is UserActionType.Request
                    and action.slot in goal.requests
                    and goal.requests[action.slot] is not None):
                continue
            # make sure to remove "old" inform actions
            if (action.type is UserActionType.Inform
                    and goal.is_inconsistent_constraint(
                        Constraint(action.slot, action.value))):
                continue
            kept.append(action)
        # restore the bottom-to-top order with one reversal instead of inserting every action
        # at the front of the list
        kept.reverse()
        self.stack = kept

    def clear(self):
        """Empties the agenda."""
        self.stack.clear()

    def is_empty(self):
        """Checks whether the agenda is empty.

        Returns:
            (bool): True if agenda is empty, False otherwise.
        """
        return len(self.stack) == 0

    def contains_action_of_type(self, act_type: UserActionType, consider_dontcare=True):
        """Checks whether agenda contains actions of a specific type.

        Args:
            act_type (UserActionType): The action type (intent) for which the agenda will be checked.
            consider_dontcare (bool): If set to True also considers actions for which the value is
                'dontcare', and ignores them otherwise.

        Returns:
            (bool): True if agenda contains *act_type*, False otherwise.
        """
        for _action in self.stack:
            if not consider_dontcare and _action.value == 'dontcare':
                continue
            if _action.type == act_type:
                return True
        return False

    def get_actions_of_type(self, act_type: UserActionType, consider_dontcare: bool = True):
        """Get actions of a specific type from the agenda.

        Args:
            act_type (UserActionType): The action type (intent) for which the agenda will be checked.
            consider_dontcare (bool): If set to True also considers actions for which the value is
                'dontcare', and ignores them otherwise.

        Returns:
            (Iterable[UserAct]): A lazy iterator (``filter`` object) over the user actions of
            the given type/intent; it can be consumed only once.
        """
        return filter(
            lambda x: x.type == act_type
            and (consider_dontcare or x.value != 'dontcare'), self.stack)

    def remove_actions_of_type(self, act_type: UserActionType):
        """Removes actions of a specific type from the agenda.

        Args:
            act_type (UserActionType): The action type (intent) which will be removed from the agenda.
        """
        self.stack = list(filter(lambda x: x.type != act_type, self.stack))

    def remove_actions(self, act_type: UserActionType, slot: str, value: str = None):
        """Removes actions of a specific type, slot and optionally value from the agenda. All
        arguments (value only if given) have to match in conjunction.

        Args:
            act_type (UserActionType): The action type (intent) which will be removed from the agenda.
            slot (str): The slot an action must have in order to be removed.
            value (str): The value an action must have in order to be removed. If None, the
                value of an action is not taken into account.
        """
        if value is None:
            self.stack = list(filter(lambda x: x.type != act_type or x.slot != slot, self.stack))
        else:
            self.stack = list(filter(
                lambda x: x.type != act_type or x.slot != slot or x.value != value, self.stack))

    def fill_with_requests(self, goal: Goal, exclude_name: bool = True):
        """Adds all request actions to the agenda necessary to fulfill the *goal*.

        Args:
            goal (Goal): The current goal of the (simulated) user for which actions will be pushed to the
                agenda.
            exclude_name (bool): If True, no request action is added for the slot 'name'.
        """
        # add a request action for every request which has not been answered yet
        for key, value in goal.requests.items():
            if (not exclude_name or key != 'name') and value is None:
                self.stack.append(
                    UserAct(act_type=UserActionType.Request, slot=key, value=value, score=1.0))

    def fill_with_constraints(self, goal: Goal):
        """
        Adds all inform actions to the agenda necessary to fulfill the *goal*. Generally there is
        no need to add all constraints from the goal to the agenda apart from the initialisation.

        Args:
            goal (Goal): The current goal of the (simulated) user for which actions will be pushed to the agenda.
        """
        # add informs from goal
        for constraint in goal.constraints:
            self.stack.append(UserAct(
                act_type=UserActionType.Inform,
                slot=constraint.slot,
                value=constraint.value, score=1.0))
__bool__(self)
special
¶
__contains__(self, value)
special
¶
__init__(self)
special
¶
__iter__(self)
special
¶
__len__(self)
special
¶
__repr__(self)
special
¶
__str__(self)
special
¶
clean(self, goal)
¶
Cleans the agenda, i.e. makes sure that actions are consistent with goal and in the correct order.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
goal |
Goal |
The goal which is needed to determine the consistent actions. |
required |
Source code in adviser/services/simulator/simulator.py
def clean(self, goal: Goal):
    """Cleans the agenda, i.e. makes sure that actions are consistent with goal and in the
    correct order.

    Duplicated actions (only the most recent occurrence is kept), request actions whose slot
    has already been answered in the goal and inform actions contradicting the goal are
    removed.

    Args:
        goal (Goal): The goal which is needed to determine the consistent actions.
    """
    kept = []  # collected from the top of the agenda downwards
    # reverse order since most recent actions are on top of agenda
    for action in reversed(self.stack):
        if action in kept:
            continue
        # NOTE sufficient if there is only one slot per (request) action
        # remove accomplished requests
        if (action.type is UserActionType.Request
                and action.slot in goal.requests
                and goal.requests[action.slot] is not None):
            continue
        # make sure to remove "old" inform actions
        if (action.type is UserActionType.Inform
                and goal.is_inconsistent_constraint(
                    Constraint(action.slot, action.value))):
            continue
        kept.append(action)
    # restore the bottom-to-top order with one reversal instead of inserting every action at
    # the front of the list
    kept.reverse()
    self.stack = kept
clear(self)
¶
contains_action_of_type(self, act_type, consider_dontcare=True)
¶
Checks whether agenda contains actions of a specific type.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
act_type |
UserActionType |
The action type (intent) for which the agenda will be checked. |
required |
consider_dontcare |
bool |
If set to True also considers actions for which the value is 'dontcare', and ignores them otherwise. |
True |
Returns:
Type | Description |
---|---|
(bool) |
True if agenda contains act_type, False otherwise. |
Source code in adviser/services/simulator/simulator.py
def contains_action_of_type(self, act_type: UserActionType, consider_dontcare=True):
    """Tells whether at least one action of the given type is on the agenda.

    Args:
        act_type (UserActionType): The action type (intent) to look for.
        consider_dontcare (bool): If True, actions whose value is 'dontcare' count as well;
            if False, such actions are skipped.

    Returns:
        (bool): True if agenda contains *act_type*, False otherwise.
    """
    return any(
        candidate.type == act_type
        for candidate in self.stack
        if consider_dontcare or candidate.value != 'dontcare')
fill_with_constraints(self, goal)
¶
Adds all inform actions to the agenda necessary to fulfill the goal. Generally there is no need to add all constraints from the goal to the agenda apart from the initialisation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
goal |
Goal |
The current goal of the (simulated) user for which actions will be pushed to the agenda. |
required |
Source code in adviser/services/simulator/simulator.py
def fill_with_constraints(self, goal: Goal):
    """
    Pushes one inform action per constraint of the *goal* onto the agenda. Apart from the
    initialisation there is generally no need to add all constraints of the goal.

    Args:
        goal (Goal): The current goal of the (simulated) user for which actions will be pushed to the agenda.
    """
    # one inform per goal constraint, in the order of the goal's constraints
    self.stack.extend(
        UserAct(
            act_type=UserActionType.Inform,
            slot=goal_constraint.slot,
            value=goal_constraint.value,
            score=1.0)
        for goal_constraint in goal.constraints)
fill_with_requests(self, goal, exclude_name=True)
¶
Adds all request actions to the agenda necessary to fulfill the goal.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
goal |
Goal |
The current goal of the (simulated) user for which actions will be pushed to the agenda. |
required |
exclude_name |
bool |
Whether to leave out the action requesting the entity's name; if True, no request action is added for the slot 'name'. |
True |
Source code in adviser/services/simulator/simulator.py
def fill_with_requests(self, goal: Goal, exclude_name: bool = True):
    """Adds all request actions to the agenda necessary to fulfill the *goal*.

    Args:
        goal (Goal): The current goal of the (simulated) user for which actions will be pushed to the
            agenda.
        exclude_name (bool): If True, no request action is added for the slot 'name'.
    """
    # add a request action for every request which has not been answered yet
    for key, value in goal.requests.items():
        if (not exclude_name or key != 'name') and value is None:
            self.stack.append(
                UserAct(act_type=UserActionType.Request, slot=key, value=value, score=1.0))
get_actions(self, num_actions)
¶
Retrieves num_actions actions from the agenda.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
num_actions |
int |
Amount of actions which will be retrieved from the agenda. |
required |
Returns:
Type | Description |
---|---|
(List[UserAct]) |
list of num_actions user actions. |
Source code in adviser/services/simulator/simulator.py
def get_actions(self, num_actions: int):
    """Pops up to *num_actions* actions off the top of the agenda.

    Args:
        num_actions (int): Amount of actions which will be retrieved from the agenda. A
            negative amount or an amount larger than the agenda retrieves all actions.

    Returns:
        (List[UserAct]): list of the retrieved user actions, topmost action first.
    """
    available = len(self.stack)
    amount = num_actions if 0 <= num_actions <= available else available
    popped = []
    for _ in range(amount):
        popped.append(self.stack.pop())
    return popped
get_actions_of_type(self, act_type, consider_dontcare=True)
¶
Get actions of a specific type from the agenda.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
act_type |
UserActionType |
The action type (intent) for which the agenda will be checked. |
required |
consider_dontcare |
bool |
If set to True also considers actions for which the value is 'dontcare', and ignores them otherwise. |
True |
Returns:
Type | Description |
---|---|
(Iterable[UserAct]) |
A list of user actions of the given type/intent. |
Source code in adviser/services/simulator/simulator.py
def get_actions_of_type(self, act_type: UserActionType, consider_dontcare: bool = True):
    """Selects the actions of a specific type from the agenda without removing them.

    Args:
        act_type (UserActionType): The action type (intent) for which the agenda will be checked.
        consider_dontcare (bool): If True, actions whose value is 'dontcare' are included;
            if False, they are left out.

    Returns:
        (Iterable[UserAct]): A lazy ``filter`` iterator over the user actions of the given
        type/intent.
    """
    def _is_wanted(action):
        if action.type == act_type:
            return consider_dontcare or action.value != 'dontcare'
        return False

    return filter(_is_wanted, self.stack)
init(self, goal)
¶
Initializes the agenda given a goal. The agenda is cleared and then filled with inform actions for the constraints in the goal. Request actions for the requests in the goal are currently not added, since the system cannot handle inform and request actions in the same turn.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
goal |
Goal |
The goal for which the agenda will be initialized. |
required |
Source code in adviser/services/simulator/simulator.py
def init(self, goal):
    """
    Initializes the agenda given a goal. The agenda is emptied and then filled with inform
    actions for the constraints in the goal.

    Args:
        goal (Goal): The goal for which the agenda will be initialized.
    """
    self.stack.clear()
    # populate agenda according to goal
    # NOTE don't push bye action here since bye action could be popped with another (missing)
    # request, but user should not end dialog before having the goal fulfilled
    # NOTE do not add requests to agenda since system can't handle inform and request action in
    # same turn currently!
    # self.fill_with_requests(goal)
    self.fill_with_constraints(goal)
is_empty(self)
¶
Checks whether the agenda is empty.
Returns:
Type | Description |
---|---|
(bool) |
True if agenda is empty, False otherwise. |
push(self, item)
¶
Pushes item onto the agenda.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
item |
The action (or list of actions) which will be pushed onto the agenda. |
required |
remove_actions(self, act_type, slot, value=None)
¶
Removes actions of a specific type, slot and optionally value from the agenda. All arguments (value only if given) have to match in conjunction.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
act_type |
UserActionType |
The action type (intent) which will be removed from the agenda. |
required |
slot |
str |
The slot an action must have in order to be removed from the agenda. |
required |
value |
str |
The value an action must have in order to be removed from the agenda; if None, the value is not taken into account. |
None |
Source code in adviser/services/simulator/simulator.py
def remove_actions(self, act_type: UserActionType, slot: str, value: str = None):
    """Removes actions of a specific type, slot and optionally value from the agenda. All
    arguments (value only if given) have to match in conjunction.

    Args:
        act_type (UserActionType): The action type (intent) which will be removed from the agenda.
        slot (str): The slot an action must have in order to be removed.
        value (str): The value an action must have in order to be removed. If None, the value
            of an action is not taken into account.
    """
    def _keep(action):
        if action.type != act_type or action.slot != slot:
            return True
        # type and slot match: the action survives only if a value was given and differs
        return value is not None and action.value != value

    self.stack = [action for action in self.stack if _keep(action)]
remove_actions_of_type(self, act_type)
¶
Removes actions of a specific type from the agenda.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
act_type |
UserActionType |
The action type (intent) which will be removed from the agenda. |
required |
Source code in adviser/services/simulator/simulator.py
HandcraftedUserSimulator (Service)
¶
The class for a handcrafted (agenda-based) user simulator.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
domain |
Domain |
The domain for which the user simulator will be instantiated. It will use this domain to generate the goals. |
required |
Source code in adviser/services/simulator/simulator.py
class HandcraftedUserSimulator(Service):
"""The class for a handcrafted (agenda-based) user simulator.
Args:
domain (Domain): The domain for which the user simulator will be instantiated. It will use
this domain to generate the goals.
"""
def __init__(self, domain: Domain, logger: DiasysLogger = DiasysLogger()):
    """Sets up the handlers, reads ``usermodel.cfg`` and creates the goal and the agenda.

    Args:
        domain (Domain): The domain for which the user simulator will be instantiated.
        logger (DiasysLogger): The logger used for dialog output.
    """
    # NOTE(review): the default ``DiasysLogger()`` is evaluated once, when the method is
    # defined, so every simulator created without an explicit logger shares that object.
    super(HandcraftedUserSimulator, self).__init__(domain)

    # possible system actions and the handler which processes each of them
    self.receive_options = {
        SysActionType.Welcome: self._receive_welcome,
        SysActionType.InformByName: self._receive_informbyname,
        SysActionType.InformByAlternatives: self._receive_informbyalternatives,
        SysActionType.Request: self._receive_request,
        SysActionType.Confirm: self._receive_confirm,
        SysActionType.Select: self._receive_select,
        SysActionType.RequestMore: self._receive_requestmore,
        SysActionType.Bad: self._receive_bad,
        SysActionType.ConfirmRequest: self._receive_confirmrequest,
    }

    # parse config file located next to this module
    self.logger = logger
    module_dir = os.path.abspath(os.path.dirname(__file__))
    self.config = configparser.ConfigParser(inline_comment_prefixes=('#', ';'))
    self.config.optionxform = str  # do not lower-case option names
    self.config.read(os.path.join(module_dir, 'usermodel.cfg'))

    def _split_list(raw):
        # "[a, b]" -> ["a", "b"]
        return raw.replace(' ', '').strip('[]').split(',')

    goal_parameters = {}
    usermodel_parameters = {}
    self.parameters = {'goal': goal_parameters, 'usermodel': usermodel_parameters}

    # goal: every option is read as a float
    for key in self.config["goal"]:
        goal_parameters[key] = float(self.config.get("goal", key))

    # usermodel
    for key in self.config["usermodel"]:
        raw_value = self.config.get("usermodel", key)
        if key == 'patience':
            # patience will be sampled on begin of each dialog
            usermodel_parameters[key] = [int(x) for x in _split_list(raw_value)]
        elif raw_value.startswith("[") and raw_value.endswith("]"):
            # value is a list to sample the probability from
            usermodel_parameters[key] = common.numpy.random.uniform(
                *[float(x) for x in _split_list(raw_value)])
        else:
            # value is the probability
            usermodel_parameters[key] = float(raw_value)

    # member declarations
    self.turn = 0
    self.domain = domain
    self.dialog_patience = None
    self.patience = None
    self.last_user_actions = None
    self.last_system_action = None
    self.excluded_venues = []

    # member definitions
    self.goal = Goal(domain, self.parameters['goal'])
    self.agenda = Agenda()
    self.num_actions_next_turn = -1
def dialog_start(self):
    """Prepares the user model for a new dialog: a new goal is initialized, the agenda is
    populated according to it and the per-dialog state (patience, last actions, excluded
    venues, turn counter) is reset."""
    # self.goal = Goal(self.domain, self.parameters['goal'])
    self.goal.init()
    self.agenda.init(self.goal)
    if self.logger:
        goal_message = "New goal has constraints {} and requests {}.".format(
            self.goal.constraints, self.goal.requests)
        self.logger.dialog_turn(goal_message)
        agenda_message = "New agenda initialized: {}".format(self.agenda)
        self.logger.dialog_turn(agenda_message)

    usermodel = self.parameters['usermodel']

    # add hello action with some probability
    if common.random.random() < usermodel['Greeting']:
        self.agenda.push(UserAct(act_type=UserActionType.Hello, score=1.0))

    # needed for possibility to reset patience
    patience_values = usermodel['patience']
    if len(patience_values) == 1:
        self.dialog_patience = patience_values[0]
    else:
        self.dialog_patience = common.random.randint(*patience_values)
    self.patience = self.dialog_patience

    self.last_user_actions = None
    self.last_system_action = None
    self.excluded_venues = []
    self.turn = 0
@PublishSubscribe(sub_topics=["sys_act", "sys_turn_over"], pub_topics=["user_acts", "sim_goal"])
def user_turn(self, sys_act: SysAct = None, sys_turn_over=False) \
        -> dict(user_acts=List[UserAct], sim_goal=Goal):
    """
    Produce the simulated user's reaction to the latest system action.

    Args:
        sys_act (SysAct): The system action for which a user response will be retrieved.
        sys_turn_over (bool): signal to start the user turn

    Returns:
        (dict): ``{"sim_goal": goal}`` if the system said goodbye, otherwise
        ``{"user_acts": [...]}`` with the actions taken from the agenda.
    """
    if sys_act is not None:
        if sys_act.type == SysActionType.Bye:
            # the system closed the dialog: only hand back the goal
            return {"sim_goal": self.goal}
        self.receive(sys_act)

    next_acts = self.respond()
    self.logger.dialog_turn("User Action: " + str(next_acts))
    return {'user_acts': next_acts}
def receive(self, sys_act: SysAct):
    """
    Update patience and agenda so that they reflect the received system action.

    Args:
        sys_act (SysAct): The action the system took
    """
    if self.last_system_action is not None:
        if sys_act == self.last_system_action:
            # the system repeated itself
            self.patience -= 1
        elif self.parameters['usermodel']['resetPatience']:
            self.patience = self.dialog_patience
    self.last_system_action = sys_act

    if self.patience == 0:
        self.logger.dialog_turn("User patience run out, ending dialog.")
        self.agenda.clear()
        self._finish_dialog(ungrateful=True)
        return

    ignored_requests, ignored_requests_alt = self._check_system_ignored_request(
        self.last_user_actions, sys_act)

    if sys_act.type not in self.receive_options:
        self.logger.error(
            "System Action Type is {}, but I don't know how to handle it!".format(
                sys_act.type))
        return

    # first stage: the type-specific handler pushes onto the agenda
    self.receive_options[sys_act.type](sys_act)

    # repeat requests from the last user turn that the system left unanswered
    if ignored_requests:
        self.agenda.push(ignored_requests)
    if ignored_requests_alt:
        self.agenda.push(ignored_requests_alt)
        # make sure to pick only the requestalt actions (should be 1)
        self.num_actions_next_turn = len(ignored_requests_alt)
        # old request actions verifying an offer are obsolete now
        self.agenda.remove_actions_of_type(act_type=UserActionType.Request)

    # second stage: clean agenda
    self.agenda.clean(self.goal)
    if not self.agenda.is_empty():
        return
    # nothing left on the agenda: either we are done or we ask again
    if self.goal.is_fulfilled():
        self._finish_dialog()
    else:
        self.agenda.fill_with_requests(self.goal, exclude_name=False)
def _receive_welcome(self, sys_act: SysAct):
    """
    Handle a welcome action from the system: intentionally a no-op.

    Args:
        sys_act (SysAct): the last system action
    """
    # Nothing to do: the first turn is already intercepted, and the policy never emits
    # 'welcome' during reinforcement learning, so it can only show up in the first turn.
def _receive_informbyname(self, sys_act: SysAct):
"""
Processes an informbyname action from the system; checks if the inform matches the
goal constraints and if yes, will add unanswered requests to the agenda

Values of the 'name' slot are collected as offers; every other slot/value pair is
wrapped in a Constraint and handed to `_check_offer` together with the offers.
Args:
sys_act (SysAct): the last system action
"""
# check all system informs for offer
inform_list = []
offers = []
for slot, value_list in sys_act.slot_values.items():
for value in value_list:
if slot == 'name':
offers.append(value)
else:
inform_list.append(Constraint(slot, value))
# check offer
if offers:
if self._check_offer(offers, inform_list):
# valid offer
# each informed constraint is unpacked into (slot, value) and recorded in the goal
for slot, value in inform_list:
self.goal.fulfill_request(slot, value)
# needed to make sure that not informed constraints (which have been turned into requests)
# will be asked first (before ending the dialog too early)
req_actions_not_in_goal = []
for action in self.agenda.get_actions_of_type(UserActionType.Request):
if action.slot not in self.goal.requests:
req_actions_not_in_goal.append(copy.deepcopy(action))
# goal might be fulfilled now
# the dialog is only finished if no inform and no extra request is still pending
if (self.goal.is_fulfilled()
and not self.agenda.contains_action_of_type(UserActionType.Inform)
and not req_actions_not_in_goal):
self._finish_dialog()
def _receive_informbyalternatives(self, sys_act: SysAct):
    """
    Handle an informbyalternatives action from the system.

    If venues have been excluded and there is currently no accepted offer, the action is
    processed exactly like an inform by name; otherwise the last user actions are repeated.

    Args:
        sys_act (SysAct): the last system action
    """
    if not self.excluded_venues or self.goal.requests[self.domain.get_primary_key()] is not None:
        self._repeat_last_actions()
        return
    # same as inform by name
    self._receive_informbyname(sys_act)
def _receive_request(self, sys_act: SysAct):
    """
    Answer a system request: for every requested slot, push an inform carrying the value
    the current goal holds for that slot.

    Args:
        sys_act (SysAct): the last system action
    """
    for requested_slot in sys_act.slot_values.keys():
        answer = UserAct(
            act_type=UserActionType.Inform,
            slot=requested_slot,
            value=self.goal.get_constraint(requested_slot),
            score=1.0)
        self.agenda.push(answer)
def _receive_confirm(self, sys_act: SysAct):
    """
    React to a confirm action from the system using the information in the user goal.

    A value that agrees with the goal is re-informed; a conflicting value is either
    corrected (with probability 'InformOnConfirm') or denied.

    Args:
        sys_act (SysAct): the last system action
    """
    for slot, confirmed_values in sys_act.slot_values.items():
        confirmed = confirmed_values[0]  # there is always only one value
        if not self.goal.is_inconsistent_constraint_strict(Constraint(slot, confirmed)):
            # NOTE using inform currently since NLU currently does not support Affirm here
            # and NLU would tinker it into an Inform action anyway
            reply = UserAct(
                act_type=UserActionType.Inform, slot=slot, value=confirmed, score=1.0)
        elif common.random.random() < self.parameters['usermodel']['InformOnConfirm']:
            # correct the system by informing about the goal's value
            reply = UserAct(
                act_type=UserActionType.Inform, slot=slot,
                value=self.goal.get_constraint(slot),
                score=1.0)
        else:
            # just deny the wrong value
            reply = UserAct(
                act_type=UserActionType.NegativeInform, slot=slot, value=confirmed, score=1.0)
        self.agenda.push(reply)
def _receive_select(self, sys_act: SysAct):
"""
Processes a select action from the system based on the simulation goal

If at least one offered value is consistent with the goal, the action is handled like a
request; otherwise the offered values are denied via NegativeInform actions, optionally
together with an inform about the goal's value (probability 'InformOnSelect').
Args:
sys_act (SysAct): the last system action
"""
# handle as request
value_in_goal = False
for slot, values in sys_act.slot_values.items():
for value in values:
# do not consider 'dontcare' as any value
if not self.goal.is_inconsistent_constraint_strict(Constraint(slot, value)):
value_in_goal = True
if value_in_goal:
self._receive_request(sys_act)
else:
# NOTE(review): assert is stripped under `python -O`; the single-slot assumption is then unchecked
assert len(sys_act.slot_values.keys()) == 1, \
"There shall be only one slot in a select action."
# NOTE: currently we support only one slot for select action,
# but this could be changed in the future
slot = list(sys_act.slot_values.keys())[0]
# inform about correct value with some probability
if common.random.random() < self.parameters['usermodel']['InformOnSelect']:
self.agenda.push(UserAct(
act_type=UserActionType.Inform, slot=slot,
value=self.goal.get_constraint(slot),
score=1.0))
# deny every value the system offered
for slot, values in sys_act.slot_values.items():
for value in values:
self.agenda.push(UserAct(
act_type=UserActionType.NegativeInform,
slot=slot,
value=value, score=1.0))
def _receive_requestmore(self, sys_act: SysAct):
"""
Processes a requestmore action from the system.

Ends the dialog when the goal is fulfilled; otherwise either refills an empty agenda
with the goal's requests or repeats the last user actions.
Args:
sys_act (SysAct): the last system action
"""
if self.goal.is_fulfilled():
# end dialog
self._finish_dialog()
# NOTE(review): the slot 'name' is hard-coded here, while other methods use
# self.domain.get_primary_key() - confirm both always agree
elif (not self.agenda.contains_action_of_type(UserActionType.Inform)
and self.goal.requests['name'] is not None):
# venue has been offered and all informs have been issued, but atleast one request slot
# is missing
if self.agenda.is_empty():
self.agenda.fill_with_requests(self.goal)
else:
# make sure that dialog becomes longer
self._repeat_last_actions()
def _receive_bad(self, sys_act:SysAct):
    """
    Handle a bad action from the system by saying the same thing again.

    Args:
        sys_act (SysAct): the last system action
    """
    # NOTE should never occur on intention-level as long no noise is used
    self._repeat_last_actions()
def _receive_confirmrequest(self, sys_act: SysAct):
    """
    Handle a confirmrequest action by splitting it into its confirm and request parts
    and delegating each part to the matching handler.

    Args:
        sys_act (SysAct): the last system action
    """
    # a slot without a value is the request part, a slot with a value the confirm part
    for slot, value in sys_act.slot_values.items():
        if value is not None:
            # NOTE SysActionType Confirm has single value only
            confirm_part = SysAct(act_type=SysActionType.Confirm, slot_values={slot: [value]})
            self._receive_confirm(confirm_part)
            continue
        request_part = SysAct(act_type=SysActionType.Request, slot_values={slot: None})
        self._receive_request(request_part)
def respond(self):
    """
    Gets n actions from the agenda, where n is drawn depending on the agenda or a pdf.

    The number of actions is, in order of precedence: a value preset in
    ``self.num_actions_next_turn`` (which is then reset to -1), ``-1`` if a bye action is
    on top of the agenda (pop all actions), or a random draw from {1, 2, 3} capped at the
    agenda length. Inform actions that get issued are removed from the goal's
    ``missing_informs``.

    Returns:
        the user actions taken from the agenda; a deep copy of them is stored in
        ``self.last_user_actions`` so that they can be repeated later.
    """
    # get some actions from the agenda
    assert len(self.agenda) > 0, "Agenda is empty, this must not happen at this point!"
    if self.num_actions_next_turn > 0:
        # use and reset self.num_actions_next_turn if set
        num_actions = self.num_actions_next_turn
        self.num_actions_next_turn = -1
    elif self.agenda.stack[-1].type == UserActionType.Bye:
        # pop all actions from agenda since agenda can only contain thanks (optional) and
        # bye action
        num_actions = -1
    else:
        # draw amount of actions
        num_actions = min(len(self.agenda), common.numpy.random.choice(
            [1, 2, 3], p=[.6, .3, .1]))  # hardcoded pdf

    # get actions from agenda
    user_actions = self.agenda.get_actions(num_actions)
    # copy needed for repeat action since they might be changed in other modules
    self.last_user_actions = copy.deepcopy(user_actions)

    # informs that are being issued now are no longer missing
    for action in user_actions:
        if action.type == UserActionType.Inform and action in self.goal.missing_informs:
            self.goal.missing_informs.remove(action)
    return user_actions
def _finish_dialog(self, ungrateful=False):
    """
    End the dialog: empty the agenda and push a bye action, preceded by a thankyou
    action with the probability configured in the user model.

    Args:
        ungrateful (bool): if True, the thankyou action is never added (used when the
            dialog ran too long or the user ran out of patience)
    """
    self.agenda.clear()  # empty agenda
    say_thanks = (not ungrateful
                  and common.random.random() < self.parameters['usermodel']['Thank'])
    if say_thanks:
        self.agenda.push(UserAct(act_type=UserActionType.Thanks, score=1.0))
    # NOTE bye has to be the topmost action on the agenda since respond() checks for it
    self.agenda.push(UserAct(act_type=UserActionType.Bye, score=1.0))
def _repeat_last_actions(self):
    """
    Put the previous user actions back on top of the agenda and make sure that exactly
    these actions are taken in the next turn. Does nothing if there are no previous actions.
    """
    previous = self.last_user_actions
    if previous is None:
        return
    # pushed in reverse order
    self.agenda.push(previous[::-1])
    self.num_actions_next_turn = len(previous)
def _alter_constraints(self, constraints, count):
"""
Alters *count* constraints from the given constraints by choosing a new value
(could be also 'dontcare').

Args:
constraints: candidate constraints; if empty, all goal constraints whose value is
not 'dontcare' are used as candidates instead
count: number of constraints to alter

Returns:
list of the new constraints (slot with its newly chosen value); empty if there is no
candidate to alter
"""
constraints_candidates = constraints[:] # copy list
if not constraints_candidates:
for _constraint in self.goal.constraints:
if _constraint.value != 'dontcare':
constraints_candidates.append(Constraint(_constraint.slot, _constraint.value))
else:
# any constraint from the current system actions has to be taken into consideration
# make sure that constraints are part of the goal since noise could have influenced the
# dialog -> given constraints must conform to the current goal
constraints_candidates = list(filter(
lambda x: not self.goal.is_inconsistent_constraint_strict(x),
constraints_candidates))
if not constraints_candidates:
return []
# NOTE(review): with replace=False numpy refuses to draw more samples than there are
# candidates - presumably count never exceeds len(constraints_candidates); confirm
constraints_to_alter = common.numpy.random.choice(
constraints_candidates, count, replace=False)
new_constraints = []
for _constraint in constraints_to_alter:
# the current value counts as tried from now on
self.goal.excluded_inf_slot_values[_constraint.slot].add(
_constraint.value)
possible_values = self.goal.inf_slot_values[_constraint.slot][:]
for _value in self.goal.excluded_inf_slot_values[_constraint.slot]:
# remove values which have been tried already
# NOTE values in self.excluded_inf_slot_values should always be in possible_values
# because the same source is used for both and to initialize the goal
# NOTE(review): if possible_values is a list, remove() raises ValueError should that
# assumption ever be violated
possible_values.remove(_value)
if not possible_values:
# add 'dontcare' as last option
possible_values.append('dontcare')
# 'dontcare' value with some probability
if common.random.random() < self.parameters['usermodel']['DontcareIfNoVenue']:
value = 'dontcare'
else:
value = common.numpy.random.choice(possible_values)
if not self.goal.update_constraint(_constraint.slot, value):
# NOTE: this case should never happen!
print(
"The given constraints (probably by the system) are not part of the goal!")
new_constraints.append(Constraint(_constraint.slot, value))
self.logger.dialog_turn(
"Goal altered! {} -> {}.".format(constraints_to_alter, new_constraints))
return new_constraints
def _check_informs(self, informed_constraints_by_system):
    """
    Compare the constraints informed by the system with the goal.

    For every constraint that is inconsistent with the goal, an inform with the goal's
    value for that slot is pushed onto the agenda; for every consistent one, the matching
    inform actions are removed from the agenda.

    Returns:
        True if all informed constraints are consistent with the goal, else False.
    """
    all_consistent = True
    for informed in informed_constraints_by_system:
        if not self.goal.is_inconsistent_constraint(informed):
            # the system got this one right -> no need to inform about it anymore
            self.agenda.remove_actions(UserActionType.Inform, *informed)
            continue
        all_consistent = False
        correction = UserAct(
            act_type=UserActionType.Inform,
            slot=informed.slot,
            value=self.goal.get_constraint(informed.slot), score=1.0)
        self.agenda.push(correction)
    return all_consistent
def _check_offer(self, offers, informed_constraints_by_system):
"""
Checks for an offer and returns True if the offer is valid.

Args:
offers: values the system gave for the 'name' slot; the value 'none' signals that
the system has no offer
informed_constraints_by_system: constraints the system informed about along with
the offer

Returns:
True if the offer is accepted (it equals the current offer, or it was newly stored in
the goal), False in every other case.
"""
if not self._check_informs(informed_constraints_by_system):
# reset offer in goal since inconsistencies have been detected and covered
self.goal.requests[self.domain.get_primary_key()] = None
return False
# TODO maybe check for current offer first since alternative with name='none' by system
# would trigger goal change -> what is the correct action in this case?
if offers:
if 'none' not in offers:
# offer was given
# convert informs of values != 'dontcare' to requests
actions_to_convert = list(self.agenda.get_actions_of_type(
UserActionType.Inform, consider_dontcare=False))
if len(self.goal.constraints) > 1 and len(actions_to_convert) == len(self.goal.constraints):
# penalise too early offers
self._repeat_last_actions()
# NOTE(review): _repeat_last_actions() already sets num_actions_next_turn and guards
# against last_user_actions being None; the len() below raises TypeError in that case
self.num_actions_next_turn = len(self.last_user_actions)
return False
# ask for values of remaining inform slots on agenda - this has two purposes:
# 1. making sure that offer is consistent with goal
# 2. making sure that inconsistent offers prolongate a dialog
for action in actions_to_convert:
self.agenda.push(UserAct(
act_type=UserActionType.Request,
slot=action.slot,
value=None, score=1.0))
self.agenda.remove_actions_of_type(UserActionType.Inform)
if self.goal.requests[self.domain.get_primary_key()] is not None:
if self.goal.requests[self.domain.get_primary_key()] in offers:
# offer is the same, don't change anything but treat offer as valid
return True
else:
# offer is not the same, but did not request a new one
# NOTE with current bst do not (negative) inform about the offer, because
# it will only set the proability to zero -> will not be excluded
# self.agenda.push(UserAct(act_type=UserActionType.NegativeInform,\
# slot=self.domain.get_primary_key(), value=offers[0]))
return False
else:
for _offer in offers:
if _offer not in self.excluded_venues:
# offer is not on the exclusion list (e.g. from reqalt action) and
# there is no current offer
# sometimes ask for alternative
if common.random.random() < self.parameters['usermodel']['ReqAlt']:
self._request_alt(_offer)
return False
else:
# accept the offer and store it in the goal
self.goal.requests[self.domain.get_primary_key()] = _offer
for _action in self.goal.missing_informs:
# informed constraints by system are definitely consistent with
# goal at this point
if Constraint(_action.slot, _action.value) not in informed_constraints_by_system:
self.agenda.push(UserAct(
act_type=UserActionType.Request,
slot=_action.slot,
value=None))
return True
# no valid offer was given
self._request_alt()
return False
else:
# no offer was given
# TODO add probability to choose number of alternations
altered_constraints = self._alter_constraints(informed_constraints_by_system, 1)
# reset goal push new actions on top of agenda
self.goal.reset()
# after the reset every goal constraint has to be informed again
self.goal.missing_informs = [UserAct(
act_type=UserActionType.Inform,
slot=_constraint.slot,
value=_constraint.value) for _constraint in self.goal.constraints]
for _constraint in altered_constraints:
self.agenda.push(UserAct(
act_type=UserActionType.Inform,
slot=_constraint.slot,
value=_constraint.value,
score=1.0))
self.agenda.clean(self.goal)
return False
return False
def _request_alt(self, offer=None):
"""
Handles the case where a user might want to ask for an alternative offer

Args:
offer: the offer to put on the exclusion list; None if there is none to exclude
"""
# add current offer to exclusion list, reset current offer and request alternative
if offer is not None:
self.excluded_venues.append(offer)
# an offer already stored in the goal is excluded as well and then cleared
if self.goal.requests[self.domain.get_primary_key()] is not None:
self.excluded_venues.append(self.goal.requests[self.domain.get_primary_key()])
self.goal.requests[self.domain.get_primary_key()] = None
self.goal.reset()
self.agenda.push(UserAct(act_type=UserActionType.RequestAlternatives))
def _check_system_ignored_request(self, user_actions: List[UserAct], sys_act: SysAct):
    """
    Find the requests of the last user turn which the system did not react to.

    Args:
        user_actions (List[UserAct]): the actions of the last user turn (may be None/empty)
        sys_act (SysAct): the system's answer to those actions

    Returns:
        tuple ``(requests, requests_alt)``: the unanswered request actions and the
        unanswered request-alternatives actions; both empty if the last user turn
        contained no request action at all.
    """
    nothing_ignored = ([], [])
    if not user_actions:
        # no user_actions -> system ignored nothing
        return nothing_ignored

    pending_requests = [act for act in user_actions if act.type == UserActionType.Request]
    if not pending_requests:
        # no requests -> system ignored nothing
        return nothing_ignored

    if sys_act.type in [SysActionType.InformByName]:
        # slots the system informed about count as answered
        pending_requests = [req for req in pending_requests
                            if req.slot not in sys_act.slot_values]

    pending_alt = [act for act in user_actions
                   if act.type == UserActionType.RequestAlternatives]
    if sys_act.type == SysActionType.InformByAlternatives:
        offered = sys_act.slot_values[self.domain.get_primary_key()]
        if set(offered) - set(self.excluded_venues):
            # at least one offered venue has not been excluded yet
            pending_alt = []
    return pending_requests, pending_alt
__init__(self, domain, logger=<DiasysLogger adviser (NOTSET)>)
special
¶
Source code in adviser/services/simulator/simulator.py
def __init__(self, domain: Domain, logger: DiasysLogger = DiasysLogger()):
    """
    Set up the handler table, read 'usermodel.cfg' (located next to this module) into
    ``self.parameters`` and create the goal and agenda objects.

    Args:
        domain (Domain): the domain the simulated user talks about
        logger (DiasysLogger): logger used for dialog output
    """
    super(HandcraftedUserSimulator, self).__init__(domain)

    # one handler per system action type
    self.receive_options = {
        SysActionType.Welcome: self._receive_welcome,
        SysActionType.InformByName: self._receive_informbyname,
        SysActionType.InformByAlternatives: self._receive_informbyalternatives,
        SysActionType.Request: self._receive_request,
        SysActionType.Confirm: self._receive_confirm,
        SysActionType.Select: self._receive_select,
        SysActionType.RequestMore: self._receive_requestmore,
        SysActionType.Bad: self._receive_bad,
        SysActionType.ConfirmRequest: self._receive_confirmrequest,
    }

    # parse config file
    self.logger = logger
    self.config = configparser.ConfigParser(inline_comment_prefixes=('#', ';'))
    self.config.optionxform = str  # keep option names as written
    module_dir = os.path.abspath(os.path.dirname(__file__))
    self.config.read(os.path.join(module_dir, 'usermodel.cfg'))

    self.parameters = {}

    # goal section: plain floats
    goal_params = {}
    self.parameters['goal'] = goal_params
    for option in self.config["goal"]:
        goal_params[option] = float(self.config.get("goal", option))

    # usermodel section
    usermodel_params = {}
    self.parameters['usermodel'] = usermodel_params
    for option in self.config["usermodel"]:
        raw = self.config.get("usermodel", option)
        if option in ['patience']:
            # kept as a list of ints; the patience is sampled at the begin of each dialog
            usermodel_params[option] = [
                int(x) for x in raw.replace(' ', '').strip('[]').split(',')]
        elif raw.startswith("[") and raw.endswith("]"):
            # value is a list to sample the probability from
            bounds = [float(x) for x in raw.replace(' ', '').strip('[]').split(',')]
            usermodel_params[option] = common.numpy.random.uniform(*bounds)
        else:
            # value is the probability
            usermodel_params[option] = float(raw)

    # member declarations
    self.turn = 0
    self.domain = domain
    self.dialog_patience = None
    self.patience = None
    self.last_user_actions = None
    self.last_system_action = None
    self.excluded_venues = []

    # member definitions
    self.goal = Goal(domain, self.parameters['goal'])
    self.agenda = Agenda()
    self.num_actions_next_turn = -1
dialog_start(self)
¶
Resets the user model at the beginning of a dialog, e.g. draws a new goal and populates the agenda according to the goal.
Source code in adviser/services/simulator/simulator.py
def dialog_start(self):
    """Prepare the simulator for a fresh dialog.

    Re-initialises the goal and the agenda, optionally opens with a greeting, picks the
    patience for this dialog and clears all per-dialog bookkeeping.
    """
    # self.goal = Goal(self.domain, self.parameters['goal'])
    self.goal.init()
    self.agenda.init(self.goal)
    if self.logger:
        goal_msg = "New goal has constraints {} and requests {}.".format(
            self.goal.constraints, self.goal.requests)
        self.logger.dialog_turn(goal_msg)
        agenda_msg = "New agenda initialized: {}".format(self.agenda)
        self.logger.dialog_turn(agenda_msg)

    # greet with some probability
    if common.random.random() < self.parameters['usermodel']['Greeting']:
        self.agenda.push(UserAct(act_type=UserActionType.Hello, score=1.0))

    # keep the drawn value around so that the patience can be reset during the dialog
    patience_bounds = self.parameters['usermodel']['patience']
    if len(patience_bounds) == 1:
        self.dialog_patience = patience_bounds[0]
    else:
        self.dialog_patience = common.random.randint(*patience_bounds)
    self.patience = self.dialog_patience

    # per-dialog bookkeeping
    self.last_user_actions = None
    self.last_system_action = None
    self.excluded_venues = []
    self.turn = 0
receive(self, sys_act)
¶
This function makes sure that the agenda reflects all changes needed for the received system action.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
sys_act |
SysAct |
The action the system took |
required |
Source code in adviser/services/simulator/simulator.py
def receive(self, sys_act: SysAct):
    """
    Update patience and agenda so that they reflect the received system action.

    Args:
        sys_act (SysAct): The action the system took
    """
    if self.last_system_action is not None:
        if sys_act == self.last_system_action:
            # the system repeated itself
            self.patience -= 1
        elif self.parameters['usermodel']['resetPatience']:
            self.patience = self.dialog_patience
    self.last_system_action = sys_act

    if self.patience == 0:
        self.logger.dialog_turn("User patience run out, ending dialog.")
        self.agenda.clear()
        self._finish_dialog(ungrateful=True)
        return

    ignored_requests, ignored_requests_alt = self._check_system_ignored_request(
        self.last_user_actions, sys_act)

    if sys_act.type not in self.receive_options:
        self.logger.error(
            "System Action Type is {}, but I don't know how to handle it!".format(
                sys_act.type))
        return

    # first stage: the type-specific handler pushes onto the agenda
    self.receive_options[sys_act.type](sys_act)

    # repeat requests from the last user turn that the system left unanswered
    if ignored_requests:
        self.agenda.push(ignored_requests)
    if ignored_requests_alt:
        self.agenda.push(ignored_requests_alt)
        # make sure to pick only the requestalt actions (should be 1)
        self.num_actions_next_turn = len(ignored_requests_alt)
        # old request actions verifying an offer are obsolete now
        self.agenda.remove_actions_of_type(act_type=UserActionType.Request)

    # second stage: clean agenda
    self.agenda.clean(self.goal)
    if not self.agenda.is_empty():
        return
    # nothing left on the agenda: either we are done or we ask again
    if self.goal.is_fulfilled():
        self._finish_dialog()
    else:
        self.agenda.fill_with_requests(self.goal, exclude_name=False)
respond(self)
¶
Gets n actions from the agenda, where n is drawn depending on the agenda or a pdf.
Source code in adviser/services/simulator/simulator.py
def respond(self):
    """
    Gets n actions from the agenda, where n is drawn depending on the agenda or a pdf.

    The number of actions is, in order of precedence: a value preset in
    ``self.num_actions_next_turn`` (which is then reset to -1), ``-1`` if a bye action is
    on top of the agenda (pop all actions), or a random draw from {1, 2, 3} capped at the
    agenda length. Inform actions that get issued are removed from the goal's
    ``missing_informs``.

    Returns:
        the user actions taken from the agenda; a deep copy of them is stored in
        ``self.last_user_actions`` so that they can be repeated later.
    """
    # get some actions from the agenda
    assert len(self.agenda) > 0, "Agenda is empty, this must not happen at this point!"
    if self.num_actions_next_turn > 0:
        # use and reset self.num_actions_next_turn if set
        num_actions = self.num_actions_next_turn
        self.num_actions_next_turn = -1
    elif self.agenda.stack[-1].type == UserActionType.Bye:
        # pop all actions from agenda since agenda can only contain thanks (optional) and
        # bye action
        num_actions = -1
    else:
        # draw amount of actions
        num_actions = min(len(self.agenda), common.numpy.random.choice(
            [1, 2, 3], p=[.6, .3, .1]))  # hardcoded pdf

    # get actions from agenda
    user_actions = self.agenda.get_actions(num_actions)
    # copy needed for repeat action since they might be changed in other modules
    self.last_user_actions = copy.deepcopy(user_actions)

    # informs that are being issued now are no longer missing
    for action in user_actions:
        if action.type == UserActionType.Inform and action in self.goal.missing_informs:
            self.goal.missing_informs.remove(action)
    return user_actions
user_turn(self, *args, **kwargs)
¶
Source code in adviser/services/simulator/simulator.py
def delegate(self, *args, **kwargs):
# Calls the wrapped function and, if this method is registered as a publisher, sends the
# returned values to their topics.
# NOTE(review): `func`, `pub_topics` and `_send_msg` are not defined in this block -
# presumably they come from an enclosing decorator scope; confirm against the full source.
func_inst = getattr(self, func.__name__)
callargs = list(args)
if self in callargs: # remove self when in *args, because already known to function
callargs.remove(self)
result = func(self, *callargs, **kwargs)
if result:
# fix! (user could have multiple "/" characters in topic - only use last one )
# result keys may look like "topic/domain": remember the domain part per topic
# (empty string if there is none) and re-key the result by the bare topic
domains = {res.split("/")[0]: res.split("/")[1] if "/" in res else "" for res in result}
result = {key.split("/")[0]: result[key] for key in result}
if func_inst not in self._publish_sockets:
# not a publisher, just normal function
return result
socket = self._publish_sockets[func_inst]
domain = self._domain_name
if socket and result:
# publish messages
for topic in pub_topics:
# for topic in result: # NOTE publish any returned value in dict with it's key as topic
if topic in result:
# the service's own domain name takes precedence over the one from the result key
domain = domain if domain else domains[topic]
topic_domain_str = f"{topic}/{domain}" if domain else topic
# an entry in _pub_topic_domains overrides the domain for this topic
if topic in self._pub_topic_domains:
topic_domain_str = f"{topic}/{self._pub_topic_domains[topic]}" if self._pub_topic_domains[topic] else topic
_send_msg(socket, topic_domain_str, result[topic])
if self.debug_logger:
self.debug_logger.info(
f"- (DS): sent message from {func} to topic {topic_domain_str}:\n {result[topic]}")
return result
stats
special
¶
evaluation
¶
ObjectiveReachedEvaluator
¶
Evaluate single turns and complete dialog.
This class assigns a negative reward to each turn. In case the user's goal could be satisfied (meaning a matching database entry was found), a large final reward is returned.
Only needed when training against a simulator.
Source code in adviser/services/stats/evaluation.py
class ObjectiveReachedEvaluator(object):
    """ Evaluate single turns and complete dialog.

    This class assigns a negative reward to each turn.
    In case the user's goal could be satisfied (meaning a matching database
    entry was found), a large final reward is returned.

    Only needed when training against a simulator.
    """

    def __init__(self, domain: Domain, turn_reward=-1, success_reward=20,
                 logger: DiasysLogger = DiasysLogger()):
        """
        Args:
            domain (Domain): domain whose database is used to verify the offered entity
            turn_reward: reward for a single turn; must not be positive
            success_reward: final reward handed out for a successful dialog
            logger (DiasysLogger): logger used to report the evaluation result
        """
        assert turn_reward <= 0.0, 'the turn reward should be negative'
        self.domain = domain
        self.turn_reward = turn_reward
        self.success_reward = success_reward
        self.logger = logger

    def get_turn_reward(self):
        """
        Get the reward for one turn

        Returns:
            (int): the reward for the given turn
        """
        return self.turn_reward

    def get_final_reward(self, sim_goal: Goal, logging=True):
        """
        Check whether the user's goal was completed.

        The dialog counts as successful if every request of the goal has been answered,
        an entity was offered, and that entity's database entry matches all goal
        constraints (constraints with value 'dontcare' always match).

        Args:
            sim_goal (Goal): the simulation's goal
            logging (bool): whether or not the evaluation results should be logged

        Returns:
            float: Reward - the final reward (0.0 if unsuccessful, otherwise the configured
                success reward, 20 by default)
            bool: Success
        """
        requests = sim_goal.requests
        constraints = sim_goal.constraints  # list of constraints

        # an unanswered request or a missing offer means failure
        if None in requests.values() or requests['name'] == 'none':
            if logging:
                self.logger.dialog_turn("Fail with user requests \n{}".format(requests))
            return 0.0, False

        # TODO think about this more? if goals not satisfiable,
        # should system take the blame? not fair
        db_matches = self.domain.find_info_about_entity(
            entity_id=requests['name'],
            requested_slots=[constraint.slot for constraint in constraints])
        if db_matches:
            match = db_matches[0]
            for const in constraints:
                if const.value != match[const.slot] and const.value != 'dontcare':
                    if logging:
                        self.logger.dialog_turn(
                            "Fail with user requests \n{}".format(requests))
                    return 0.0, False
            if logging:
                self.logger.dialog_turn("Success with user requests \n{}".format(requests))
            # hand out the configured reward instead of a hard-coded 20.0
            return float(self.success_reward), True

        # the offered entity is not in the database
        if logging:
            self.logger.dialog_turn("Fail with user requests \n{}".format(requests))
        return 0.0, False
__init__(self, domain, turn_reward=-1, success_reward=20, logger=<DiasysLogger adviser (NOTSET)>)
special
¶
Source code in adviser/services/stats/evaluation.py
get_final_reward(self, sim_goal, logging=True)
¶
Check whether the user's goal was completed.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
sim_goal |
Goal |
the simulation's goal |
required |
logging |
bool |
whether or not the evaluation results should be logged |
True |
Returns:
Type | Description |
---|---|
float |
Reward - the final reward (0 (unsuccessful) or 20 (successful)) |
bool |
Success |
Source code in adviser/services/stats/evaluation.py
def get_final_reward(self, sim_goal: Goal, logging=True):
    """
    Check whether the user's goal was completed.

    The dialog counts as successful if every request of the goal has been answered,
    an entity was offered, and that entity's database entry matches all goal
    constraints (constraints with value 'dontcare' always match).

    Args:
        sim_goal (Goal): the simulation's goal
        logging (bool): whether or not the evaluation results should be logged

    Returns:
        float: Reward - the final reward (0.0 if unsuccessful, otherwise the configured
            success reward, 20 by default)
        bool: Success
    """
    requests = sim_goal.requests
    constraints = sim_goal.constraints  # list of constraints

    # an unanswered request or a missing offer means failure
    if None in requests.values() or requests['name'] == 'none':
        if logging:
            self.logger.dialog_turn("Fail with user requests \n{}".format(requests))
        return 0.0, False

    # TODO think about this more? if goals not satisfiable,
    # should system take the blame? not fair
    db_matches = self.domain.find_info_about_entity(
        entity_id=requests['name'],
        requested_slots=[constraint.slot for constraint in constraints])
    if db_matches:
        match = db_matches[0]
        for const in constraints:
            if const.value != match[const.slot] and const.value != 'dontcare':
                if logging:
                    self.logger.dialog_turn(
                        "Fail with user requests \n{}".format(requests))
                return 0.0, False
        if logging:
            self.logger.dialog_turn("Success with user requests \n{}".format(requests))
        # hand out the configured reward instead of a hard-coded 20.0
        return float(self.success_reward), True

    # the offered entity is not in the database
    if logging:
        self.logger.dialog_turn("Fail with user requests \n{}".format(requests))
    return 0.0, False
get_turn_reward(self)
¶
PolicyEvaluator (Service)
¶
Policy evaluation module
Plug this module into the dialog graph (somewhere after the policy), and policy metrics like success rate and reward will be recorded.
Source code in adviser/services/stats/evaluation.py
class PolicyEvaluator(Service):
""" Policy evaluation module
Plug this module into the dialog graph (somewhere *after* the policy),
and policy metrics like success rate and reward will be recorded.
"""
def __init__(self, domain: Domain, subgraph: dict = None, use_tensorboard=False,
             experiment_name: str = '', turn_reward=-1, success_reward=20,
             logger: DiasysLogger = DiasysLogger(), summary_writer=None):
    """
    Keyword Arguments:
        use_tensorboard {bool} -- [If true, metrics will be written to
                                  tensorboard in a *runs* directory]
                                  (default: {False})
        experiment_name {str} -- [Name suffix for the log files]
                                  (default: {''})
        turn_reward {float} -- [Reward for one turn - usually negative to
                                penalize dialog length] (default: {-1})
        success_reward {float} -- [Reward of the final transition if the
                                   dialog goal was reached] (default: {20})
        logger {DiasysLogger} -- [Logger handed on to the reward evaluator]
        summary_writer -- [Writer that episode rewards are reported to via
                           add_scalar; nothing is written if None]
                           (default: {None})

    NOTE(review): subgraph, use_tensorboard and experiment_name are not used inside this
    constructor - confirm whether they are still needed.
    """
    super(PolicyEvaluator, self).__init__(domain)
    self.logger = logger
    self.epoch = 0
    self.evaluator = ObjectiveReachedEvaluator(
        domain, turn_reward=turn_reward, success_reward=success_reward, logger=logger)
    self.writer = summary_writer

    self.total_train_dialogs = 0
    self.total_eval_dialogs = 0

    self.epoch_train_dialogs = 0
    self.epoch_eval_dialogs = 0
    self.train_rewards = []
    self.eval_rewards = []
    self.train_success = []
    self.eval_success = []
    self.train_turns = []
    self.eval_turns = []
    self.is_training = False

    # per-dialog accumulators; initialised here so that evaluate_turn() / end_dialog()
    # do not raise AttributeError when they run before the first dialog_start()
    self.dialog_reward = 0.0
    self.dialog_turns = 0
@PublishSubscribe(sub_topics=['sys_act'], pub_topics=["sys_turn_over"])
def evaluate_turn(self, sys_act: SysAct = None):
"""
Evaluates the reward for a given turn
Args:
sys_act (SysAct): the system action
Returns:
(bool): A signal representing the end of a complete dialog turn
"""
self.dialog_reward += self.evaluator.get_turn_reward()
self.dialog_turns += 1
return {"sys_turn_over": True}
def dialog_start(self, dialog_start=False):
"""
Clears the state of the evaluator in preparation to start a new dialog
"""
self.dialog_reward = 0.0
self.dialog_turns = 0
def train(self):
"""
sets the evaluator in train mode
"""
self.is_training = True
def eval(self):
"""
sets teh evaluator in eval mode
"""
self.is_training = False
@PublishSubscribe(sub_topics=["sim_goal"], pub_topics=["dialog_end"])
def end_dialog(self, sim_goal: Goal):
"""
Method for handling the end of a dialog; calculates the the final reward.
Args:
sim_goal (Goal): the simulation goal to evaluate against
Returns:
(dict): a dictionary where the key is "dialog_end" and the value is true
"""
if self.is_training:
self.total_train_dialogs += 1
self.epoch_train_dialogs += 1
else:
self.total_eval_dialogs += 1
self.epoch_eval_dialogs += 1
if sim_goal is None:
# real user interaction, no simulator - don't have to evaluate
# anything, just reset counters
return {"dialog_end": True}
final_reward, success = self.evaluator.get_final_reward(sim_goal)
self.dialog_reward += final_reward
if self.is_training:
self.train_rewards.append(self.dialog_reward)
self.train_success.append(int(success))
self.train_turns.append(self.dialog_turns)
if self.writer is not None:
self.writer.add_scalar('train/episode_reward', self.dialog_reward,
self.total_train_dialogs)
else:
self.eval_rewards.append(self.dialog_reward)
self.eval_success.append(int(success))
self.eval_turns.append(self.dialog_turns)
if self.writer is not None:
self.writer.add_scalar('eval/episode_reward', self.dialog_reward,
self.total_eval_dialogs)
return {"dialog_end": True}
def start_epoch(self):
"""
Handles resetting variables between epochs
"""
# global statistics
self.epoch_train_dialogs = 0
self.epoch_eval_dialogs = 0
self.train_rewards = []
self.eval_rewards = []
self.train_success = []
self.eval_success = []
self.train_turns = []
self.eval_turns = []
self.epoch += 1
self.logger.info("###\n### EPOCH" + str(self.epoch) + " ###\n###")
def end_epoch(self):
"""
Handles calculating statistics at the end of an epoch
"""
if self.logger:
if self.epoch_train_dialogs > 0:
self.logger.result(" ### Train ###")
self.logger.result("# Num Dialogs " + str(self.epoch_train_dialogs))
self.logger.result("# Avg Turns " + str(sum(self.train_turns) / self.epoch_train_dialogs))
self.logger.result("# Avg Success " + str(sum(self.train_success) / self.epoch_train_dialogs))
self.logger.result("# Avg Reward " + str(sum(self.train_rewards) / self.epoch_train_dialogs))
if self.epoch_eval_dialogs > 0:
self.logger.result(" ### Eval ###")
self.logger.result("# Num Dialogs " + str(self.epoch_eval_dialogs))
self.logger.result("# Avg Turns " + str(sum(self.eval_turns) / self.epoch_eval_dialogs))
self.logger.result("# Avg Success " + str(sum(self.eval_success) / self.epoch_eval_dialogs))
self.logger.result("# Avg Reward " + str(sum(self.eval_rewards) / self.epoch_eval_dialogs))
if self.is_training:
return {'num_dialogs': self.epoch_train_dialogs,
'turns': sum(self.train_turns) / self.epoch_train_dialogs,
'success': float(sum(self.train_success)) / self.epoch_train_dialogs,
'reward': float(sum(self.train_rewards)) / self.epoch_train_dialogs}
else:
return {'num_dialogs': self.epoch_eval_dialogs,
'turns': sum(self.eval_turns) / self.epoch_eval_dialogs,
'success': float(sum(self.eval_success)) / self.epoch_eval_dialogs,
'reward': float(sum(self.eval_rewards)) / self.epoch_eval_dialogs}
__init__(self, domain, subgraph=None, use_tensorboard=False, experiment_name='', turn_reward=-1, success_reward=20, logger=<DiasysLogger adviser (NOTSET)>, summary_writer=None)
special
¶
Keyword arguments:
Name | Type | Description |
---|---|---|
use_tensorboard | bool | If true, metrics will be written to tensorboard in a *runs* directory (default: False) |
experiment_name | str | Name suffix for the log files (default: '') |
turn_reward | float | Reward for one turn - usually negative to penalize dialog length (default: -1) |
success_reward | float | Reward of the final transition if the dialog goal was reached (default: 20) |
Source code in adviser/services/stats/evaluation.py
def __init__(self, domain: Domain, subgraph: dict = None, use_tensorboard=False,
             experiment_name: str = '', turn_reward=-1, success_reward=20,
             logger: DiasysLogger = DiasysLogger(), summary_writer=None):
    """
    Sets up the reward evaluator and resets all dialog/epoch statistics.

    Args:
        domain (Domain): domain passed on to the Service base class and to the
                         ObjectiveReachedEvaluator
        subgraph (dict): not read by this constructor
        use_tensorboard (bool): not read by this constructor; scalars are only written
                                when `summary_writer` is given
        experiment_name (str): not read by this constructor
        turn_reward (float): Reward for one turn - usually negative to penalize
                             dialog length (default: -1)
        success_reward (float): Reward of the final transition if the dialog goal
                                was reached (default: 20)
        logger (DiasysLogger): logger used for dialog and epoch statistics
        summary_writer: stored as `self.writer`; None means no writer is used
    """
    super(PolicyEvaluator, self).__init__(domain)
    self.logger = logger
    self.epoch = 0
    self.evaluator = ObjectiveReachedEvaluator(
        domain, turn_reward=turn_reward, success_reward=success_reward, logger=logger)

    self.writer = summary_writer

    # counters over the whole lifetime of this object
    self.total_train_dialogs = 0
    self.total_eval_dialogs = 0

    # statistics of the current epoch
    self.epoch_train_dialogs = 0
    self.epoch_eval_dialogs = 0
    self.train_rewards = []
    self.eval_rewards = []
    self.train_success = []
    self.eval_success = []
    self.train_turns = []
    self.eval_turns = []
    self.is_training = False

    # Per-dialog accumulators; defined up front so that they exist even if a turn is
    # evaluated before the first dialog_start call.
    self.dialog_reward = 0.0
    self.dialog_turns = 0
dialog_start(self, dialog_start=False)
¶
end_dialog(self, *args, **kwargs)
¶
Source code in adviser/services/stats/evaluation.py
def delegate(self, *args, **kwargs):
    # Wrapper that calls `func` and, if this method is registered in
    # self._publish_sockets, sends the returned values to their topics.
    # `func`, `pub_topics` and `_send_msg` are not defined in this listing; they
    # come from the enclosing scope.
    func_inst = getattr(self, func.__name__)

    callargs = list(args)
    if self in callargs:  # remove self when in *args, because already known to function
        callargs.remove(self)

    result = func(self, *callargs, **kwargs)
    if result:
        # fix! (user could have multiple "/" characters in topic - only use last one )
        # Keys of the form "topic/domain" are split at "/": `domains` maps the first
        # part to the second part ("" if the key has no "/"), and `result` is re-keyed
        # by the first part only.
        domains = {res.split("/")[0]: res.split("/")[1] if "/" in res else "" for res in result}
        result = {key.split("/")[0]: result[key] for key in result}

    if func_inst not in self._publish_sockets:
        # not a publisher, just normal function
        return result

    socket = self._publish_sockets[func_inst]
    domain = self._domain_name
    if socket and result:
        # publish messages
        for topic in pub_topics:
            # for topic in result: # NOTE publish any returned value in dict with it's key as topic
            if topic in result:
                # NOTE(review): `domain` is overwritten here, so once it is non-empty the
                # same value is reused for all following topics of this loop - confirm
                # that this is intended.
                domain = domain if domain else domains[topic]
                topic_domain_str = f"{topic}/{domain}" if domain else topic
                if topic in self._pub_topic_domains:
                    # an entry in _pub_topic_domains replaces the domain chosen above
                    # (an empty entry publishes on the plain topic)
                    topic_domain_str = f"{topic}/{self._pub_topic_domains[topic]}" if self._pub_topic_domains[topic] else topic
                _send_msg(socket, topic_domain_str, result[topic])
                if self.debug_logger:
                    self.debug_logger.info(
                        f"- (DS): sent message from {func} to topic {topic_domain_str}:\n {result[topic]}")
    return result
end_epoch(self)
¶
Handles calculating statistics at the end of an epoch
Source code in adviser/services/stats/evaluation.py
def end_epoch(self):
    """
    Handles calculating statistics at the end of an epoch

    Returns:
        (dict): the keys 'num_dialogs', 'turns', 'success' and 'reward' for the current
                mode (train or eval); the three averages are 0.0 if no dialog was
                recorded in that mode during this epoch
    """
    train_count = self.epoch_train_dialogs
    eval_count = self.epoch_eval_dialogs

    if self.logger:
        if train_count > 0:
            self.logger.result(" ### Train ###")
            self.logger.result("# Num Dialogs " + str(train_count))
            self.logger.result("# Avg Turns " + str(sum(self.train_turns) / train_count))
            self.logger.result("# Avg Success " + str(sum(self.train_success) / train_count))
            self.logger.result("# Avg Reward " + str(sum(self.train_rewards) / train_count))
        if eval_count > 0:
            self.logger.result(" ### Eval ###")
            self.logger.result("# Num Dialogs " + str(eval_count))
            self.logger.result("# Avg Turns " + str(sum(self.eval_turns) / eval_count))
            self.logger.result("# Avg Success " + str(sum(self.eval_success) / eval_count))
            self.logger.result("# Avg Reward " + str(sum(self.eval_rewards) / eval_count))

    if self.is_training:
        count = train_count
        turns, success, rewards = self.train_turns, self.train_success, self.train_rewards
    else:
        count = eval_count
        turns, success, rewards = self.eval_turns, self.eval_success, self.eval_rewards

    if count == 0:
        # nothing recorded for this mode - avoid dividing by zero
        return {'num_dialogs': 0, 'turns': 0.0, 'success': 0.0, 'reward': 0.0}

    return {'num_dialogs': count,
            'turns': sum(turns) / count,
            'success': float(sum(success)) / count,
            'reward': float(sum(rewards)) / count}
eval(self)
¶
evaluate_turn(self, *args, **kwargs)
¶
Source code in adviser/services/stats/evaluation.py
def delegate(self, *args, **kwargs):
    # Wrapper that calls `func` and, if this method is registered in
    # self._publish_sockets, sends the returned values to their topics.
    # `func`, `pub_topics` and `_send_msg` are not defined in this listing; they
    # come from the enclosing scope.
    func_inst = getattr(self, func.__name__)

    callargs = list(args)
    if self in callargs:  # remove self when in *args, because already known to function
        callargs.remove(self)

    result = func(self, *callargs, **kwargs)
    if result:
        # fix! (user could have multiple "/" characters in topic - only use last one )
        # Keys of the form "topic/domain" are split at "/": `domains` maps the first
        # part to the second part ("" if the key has no "/"), and `result` is re-keyed
        # by the first part only.
        domains = {res.split("/")[0]: res.split("/")[1] if "/" in res else "" for res in result}
        result = {key.split("/")[0]: result[key] for key in result}

    if func_inst not in self._publish_sockets:
        # not a publisher, just normal function
        return result

    socket = self._publish_sockets[func_inst]
    domain = self._domain_name
    if socket and result:
        # publish messages
        for topic in pub_topics:
            # for topic in result: # NOTE publish any returned value in dict with it's key as topic
            if topic in result:
                # NOTE(review): `domain` is overwritten here, so once it is non-empty the
                # same value is reused for all following topics of this loop - confirm
                # that this is intended.
                domain = domain if domain else domains[topic]
                topic_domain_str = f"{topic}/{domain}" if domain else topic
                if topic in self._pub_topic_domains:
                    # an entry in _pub_topic_domains replaces the domain chosen above
                    # (an empty entry publishes on the plain topic)
                    topic_domain_str = f"{topic}/{self._pub_topic_domains[topic]}" if self._pub_topic_domains[topic] else topic
                _send_msg(socket, topic_domain_str, result[topic])
                if self.debug_logger:
                    self.debug_logger.info(
                        f"- (DS): sent message from {func} to topic {topic_domain_str}:\n {result[topic]}")
    return result
start_epoch(self)
¶
Handles resetting variables between epochs
Source code in adviser/services/stats/evaluation.py
def start_epoch(self):
    """
    Handles resetting variables between epochs

    Clears the per-epoch dialog counters and the reward/success/turn lists, increases
    the epoch counter and, if a logger is set, logs the new epoch number.
    """
    # per-epoch statistics
    self.epoch_train_dialogs = 0
    self.epoch_eval_dialogs = 0
    self.train_rewards = []
    self.eval_rewards = []
    self.train_success = []
    self.eval_success = []
    self.train_turns = []
    self.eval_turns = []
    self.epoch += 1

    # end_epoch tolerates a missing logger, so do the same here
    if self.logger:
        self.logger.info("###\n### EPOCH" + str(self.epoch) + " ###\n###")
train(self)
¶
ust
special
¶
ust
¶
HandcraftedUST (Service)
¶
A rule-based approach to user state tracking. Currently very minimalist.
Source code in adviser/services/ust/ust.py
class HandcraftedUST(Service):
    """
    Minimal rule-based user state tracker.

    Holds a single UserState object and overwrites its "emotion" and "engagement"
    entries whenever new values arrive.
    """

    def __init__(self, domain=None, logger=None):
        """
        Args:
            domain: domain handed on to the Service base class
            logger: stored as `self.logger`; not used by the methods of this class
        """
        super().__init__(domain=domain)
        self.logger = logger
        self.us = UserState()

    @PublishSubscribe(sub_topics=["emotion", "engagement"], pub_topics=["userstate"])
    def update_emotion(self, emotion: EmotionType = None, engagement: EngagementType = None) \
            -> dict(userstate=UserState):
        """
        Writes the latest emotion and engagement values into the tracked user state

        Args:
            emotion (EmotionType): the emotion identified for the user
            engagement (EngagementType): the engagement identified for the user

        Returns:
            (dict): a dictionary with the key "userstate" and the tracked UserState
                    object as value
        """
        # begin a new turn before overwriting the entries
        # (presumably this archives the previous turn - see UserState.start_new_turn)
        self.us.start_new_turn()

        self.us["engagement"] = engagement
        self.us["emotion"] = emotion
        return dict(userstate=self.us)

    def dialog_start(self):
        """
        Replaces the tracked user state with a fresh one for the new dialog
        """
        self.us = UserState()
__init__(self, domain=None, logger=None)
special
¶
dialog_start(self)
¶
update_emotion(self, *args, **kwargs)
¶
Source code in adviser/services/ust/ust.py
def delegate(self, *args, **kwargs):
    # Wrapper that calls `func` and, if this method is registered in
    # self._publish_sockets, sends the returned values to their topics.
    # `func`, `pub_topics` and `_send_msg` are not defined in this listing; they
    # come from the enclosing scope.
    func_inst = getattr(self, func.__name__)

    callargs = list(args)
    if self in callargs:  # remove self when in *args, because already known to function
        callargs.remove(self)

    result = func(self, *callargs, **kwargs)
    if result:
        # fix! (user could have multiple "/" characters in topic - only use last one )
        # Keys of the form "topic/domain" are split at "/": `domains` maps the first
        # part to the second part ("" if the key has no "/"), and `result` is re-keyed
        # by the first part only.
        domains = {res.split("/")[0]: res.split("/")[1] if "/" in res else "" for res in result}
        result = {key.split("/")[0]: result[key] for key in result}

    if func_inst not in self._publish_sockets:
        # not a publisher, just normal function
        return result

    socket = self._publish_sockets[func_inst]
    domain = self._domain_name
    if socket and result:
        # publish messages
        for topic in pub_topics:
            # for topic in result: # NOTE publish any returned value in dict with it's key as topic
            if topic in result:
                # NOTE(review): `domain` is overwritten here, so once it is non-empty the
                # same value is reused for all following topics of this loop - confirm
                # that this is intended.
                domain = domain if domain else domains[topic]
                topic_domain_str = f"{topic}/{domain}" if domain else topic
                if topic in self._pub_topic_domains:
                    # an entry in _pub_topic_domains replaces the domain chosen above
                    # (an empty entry publishes on the plain topic)
                    topic_domain_str = f"{topic}/{self._pub_topic_domains[topic]}" if self._pub_topic_domains[topic] else topic
                _send_msg(socket, topic_domain_str, result[topic])
                if self.debug_logger:
                    self.debug_logger.info(
                        f"- (DS): sent message from {func} to topic {topic_domain_str}:\n {result[topic]}")
    return result