# -*- coding: utf-8 -*-
from . import sequence as seq
from . import nodeframe
import torch
import torch.nn as nn
import numpy as np
#%%ANN + Gaussian - for one data set & one parameter
class MLPGaussian(torch.nn.Module):
    """Fully connected network that outputs a mean and a standard deviation.

    The final layer is ``2 * node_out`` wide when the layer widths are built
    automatically: the first ``node_out`` columns are returned unchanged as
    the predicted parameters and the remaining columns are passed through a
    softplus so that the predicted standard deviations are positive.

    Parameters
    ----------
    node_in : int
        Number of input features.
    node_out : int
        Number of parameters to predict.
    hidden_layer : int
        Number of hidden layers, forwarded to ``nodeframe.decreasingNode``.
    nodes : sequence of int or None
        Explicit layer widths. When ``None`` they are obtained from
        ``nodeframe.decreasingNode`` (presumably a decreasing sequence from
        ``node_in`` to ``2 * node_out`` -- confirm against ``nodeframe``).
    activation_func : str
        Name of the hidden-layer activation, forwarded to ``seq.LinearSeq``.
    """

    def __init__(self, node_in=100, node_out=1, hidden_layer=3, nodes=None,
                 activation_func='Softplus'):
        super().__init__()
        self.node_in = node_in
        self.node_out = node_out
        if nodes is None:
            # Two outputs per parameter: one mean and one standard deviation.
            nodes = nodeframe.decreasingNode(node_in=node_in, node_out=node_out*2,
                                             hidden_layer=hidden_layer, get_allNode=True)
        self.fc = seq.LinearSeq(nodes, mainActive=activation_func, finalActive='None',
                                mainBN=True, finalBN=False,
                                mainDropout='None', finalDropout='None').get_seq()

    def forward(self, x):
        """Return ``(params, sigma)``, each reshaped to ``(-1, node_out)``."""
        x = self.fc(x)
        params = x[:, :self.node_out].view(-1, self.node_out)
        # Functional softplus avoids instantiating a new module on every call.
        sigma = nn.functional.softplus(x[:, self.node_out:])
        sigma = sigma.view(-1, self.node_out)
        return params, sigma
def gaussian_loss(params, sigma, target):
    """Mean negative log-likelihood of independent Gaussians.

    See https://en.wikipedia.org/wiki/Normal_distribution

    Parameters
    ----------
    params : torch.Tensor
        Predicted means; summed over ``dim=1``, so at least 2-D.
    sigma : torch.Tensor
        Predicted standard deviations, broadcastable against ``params``.
        Must be strictly positive, since its logarithm is taken.
    target : torch.Tensor
        Target values, broadcastable against ``params``.

    Returns
    -------
    torch.Tensor
        Scalar tensor: the negative log-likelihood summed over ``dim=1``
        and averaged over the remaining entries.
    """
    # log(sqrt(2*pi)) as a Python float: no tensor allocation per call and
    # no float32 rounding of the constant when the inputs are float64.
    log_sqrt_2pi = 0.5 * float(np.log(2.0 * np.pi))
    prob = -0.5*((target - params)/sigma)**2 - torch.log(sigma) - log_sqrt_2pi
    prob = torch.sum(prob, dim=1)  # dim=1 is the parameter dimension
    return torch.mean(-prob)
#%%ANN + Multivariate Gaussian - for one data set & multiple parameters
class MLPMultivariateGaussian(torch.nn.Module):
    """Fully connected network that outputs a mean vector and a triangular factor.

    With automatically built layer widths, the final layer has
    ``2*node_out + node_out*(node_out-1)/2`` outputs, split in this order:

    * ``node_out`` values returned as the parameter vector,
    * ``node_out`` values passed through a softplus and placed on the
      diagonal of the factor (so the diagonal is positive),
    * the remaining values, written into the strict upper triangle of the
      factor in the order produced by ``torch.triu_indices``.

    Parameters
    ----------
    node_in : int
        Number of input features.
    node_out : int
        Number of parameters to predict.
    hidden_layer : int
        Number of hidden layers, forwarded to ``nodeframe.decreasingNode``.
    nodes : sequence of int or None
        Explicit layer widths; built with ``nodeframe.decreasingNode`` when
        ``None``.
    activation_func : str
        Name of the hidden-layer activation, forwarded to ``seq.LinearSeq``.
    """

    def __init__(self, node_in=100, node_out=2, hidden_layer=3, nodes=None,
                 activation_func='Softplus'):
        super().__init__()
        self.node_in = node_in
        self.node_out = node_out
        if nodes is None:
            # means + diagonal entries + strictly-upper-triangular entries
            fc_out = node_out*2 + (node_out**2 - node_out)//2
            nodes = nodeframe.decreasingNode(node_in=node_in, node_out=fc_out,
                                             hidden_layer=hidden_layer, get_allNode=True)
        self.fc = seq.LinearSeq(nodes, mainActive=activation_func, finalActive='None',
                                mainBN=True, finalBN=False,
                                mainDropout='None', finalDropout='None').get_seq()

    def forward(self, x):
        """Return ``(params, cholesky_factor)``.

        ``params`` has shape ``(batch, node_out, 1)`` and ``cholesky_factor``
        has shape ``(batch, node_out, node_out)`` with zeros below the
        diagonal.
        """
        x = self.fc(x)
        n = self.node_out
        params = x[:, :n].view(-1, n, 1)
        cholesky_diag = nn.functional.softplus(x[:, n:n*2]).view(-1, n)
        cholesky_factor = torch.diag_embed(cholesky_diag)
        n_offDiag = (n**2 - n)//2
        # With node_out == 1 there is no off-diagonal block; skipping it avoids
        # the ambiguous ``view(-1, 0)`` that raises on an empty tensor.
        if n_offDiag > 0:
            cholesky_offDiag = x[:, n*2:].view(-1, n_offDiag)
            # Build the indices on the input's device to avoid a host-to-device
            # copy on every forward pass.
            upper_index = torch.triu_indices(n, n, offset=1, device=x.device)
            cholesky_factor[:, upper_index[0], upper_index[1]] = cholesky_offDiag
        return params, cholesky_factor
#need further research
class MLPMultivariateGaussian_AvgMultiNoise(torch.nn.Module):
    """Variant of the multivariate-Gaussian MLP that pools the factor over chunks.

    The network output is split exactly as in a plain multivariate-Gaussian
    head (means, softplus diagonal, strict upper triangle). In addition,
    ``forward`` can split the batch into ``multi_noise`` equal chunks, sum the
    triangular factors of the chunks, scale by ``1/sqrt(multi_noise)`` and
    tile the result back to the full batch size.

    NOTE: experimental -- the original author marked this as needing further
    research. Presumably the batch is a stack of ``multi_noise`` noise
    realisations of the same samples; confirm against the training code.

    Parameters
    ----------
    node_in : int
        Number of input features.
    node_out : int
        Number of parameters to predict.
    hidden_layer : int
        Number of hidden layers, forwarded to ``nodeframe.decreasingNode``.
    nodes : sequence of int or None
        Explicit layer widths; built with ``nodeframe.decreasingNode`` when
        ``None``.
    activation_func : str
        Name of the hidden-layer activation, forwarded to ``seq.LinearSeq``.
    """

    def __init__(self, node_in=100, node_out=2, hidden_layer=3, nodes=None,
                 activation_func='Softplus'):
        super().__init__()
        self.node_in = node_in
        self.node_out = node_out
        if nodes is None:
            # means + diagonal entries + strictly-upper-triangular entries
            fc_out = node_out*2 + (node_out**2 - node_out)//2
            nodes = nodeframe.decreasingNode(node_in=node_in, node_out=fc_out,
                                             hidden_layer=hidden_layer, get_allNode=True)
        self.fc = seq.LinearSeq(nodes, mainActive=activation_func, finalActive='None',
                                mainBN=True, finalBN=False,
                                mainDropout='None', finalDropout='None').get_seq()

    # need further research
    def forward(self, x, multi_noise=1):
        """Return ``(params, cholesky_factor)``.

        Parameters
        ----------
        x : torch.Tensor
            Network input; the first dimension is the batch.
        multi_noise : int
            Number of equal chunks the batch is split into along dim 0 before
            the factors are pooled. ``1`` disables pooling.

        Raises
        ------
        ValueError
            If ``multi_noise > 1`` and the batch size is not a multiple of it.
        """
        x = self.fc(x)
        n = self.node_out
        params = x[:, :n].view(-1, n, 1)
        cholesky_diag = nn.functional.softplus(x[:, n:n*2]).view(-1, n)
        cholesky_factor = torch.diag_embed(cholesky_diag)
        n_offDiag = (n**2 - n)//2
        # With node_out == 1 there is no off-diagonal block; skipping it avoids
        # the ambiguous ``view(-1, 0)`` that raises on an empty tensor.
        if n_offDiag > 0:
            cholesky_offDiag = x[:, n*2:].view(-1, n_offDiag)
            upper_index = torch.triu_indices(n, n, offset=1, device=x.device)
            cholesky_factor[:, upper_index[0], upper_index[1]] = cholesky_offDiag
        if multi_noise > 1:
            batch = cholesky_factor.size(0)
            # Unequal chunks would either fail with an obscure error or be
            # silently broadcast together, so reject them up front.
            if batch % multi_noise != 0:
                raise ValueError('batch size %s is not divisible by multi_noise=%s'
                                 % (batch, multi_noise))
            chunks = torch.chunk(cholesky_factor, multi_noise, dim=0)
            pooled = chunks[0]
            for chunk in chunks[1:]:
                pooled = pooled + chunk
            pooled = pooled / (multi_noise ** 0.5)
            cholesky_factor = pooled.repeat(multi_noise, 1, 1)
        return params, cholesky_factor
def multivariateGaussian_loss(params, cholesky_factor, target):
    """Mean negative log-likelihood of a multivariate Gaussian.

    ``cholesky_factor`` is treated as a Cholesky factor of the *inverse*
    covariance matrix (see arXiv:2003.05739), so the quadratic form is
    ``||cholesky_factor @ (target - params)||^2`` and half the
    log-determinant of the inverse covariance is the sum of the logs of the
    factor's diagonal.

    Parameters
    ----------
    params : torch.Tensor
        Predicted means, shape ``(batch, n, 1)``.
    cholesky_factor : torch.Tensor
        Triangular factor, shape ``(batch, n, n)``, with a strictly positive
        diagonal (its logarithm is taken).
    target : torch.Tensor
        Target values, shape ``(batch, n)``; a trailing axis is added here.

    Returns
    -------
    torch.Tensor
        Scalar tensor: the negative log-likelihood averaged over the batch.
    """
    target = target.unsqueeze(-1)
    diff = target - params
    params_n = cholesky_factor.size(-1)
    # log(sqrt((2*pi)**n)) computed in log space: (2*pi)**n overflows float32
    # for n >= 49, which previously turned the whole loss into inf.
    log_norm = 0.5 * params_n * float(np.log(2.0 * np.pi))
    log_det_2 = torch.sum(torch.log(torch.diagonal(cholesky_factor, dim1=1, dim2=2)), dim=1)
    comb = torch.matmul(cholesky_factor, diff)
    # comb is a column vector per sample, so comb^T comb is a 1x1 matrix.
    quad = torch.matmul(comb.transpose(1, 2), comb)[:, 0, 0]
    prob = -0.5*quad + log_det_2 - log_norm
    return torch.mean(-prob)
#%% multi-branch network + (Multivariate) Gaussian - for multiple data sets & one (multiple) parameter
class MultiBranchMLPGaussian(nn.Module):
    """Multi-branch network that outputs a mean and a standard deviation.

    Each input data set is processed by its own branch (``self.branch1`` ...
    ``self.branchN``); the branch outputs are concatenated along dim 1 and
    passed through a shared trunk. With automatically built widths the trunk
    ends in ``2 * node_out`` columns: the first ``node_out`` are the predicted
    parameters and the rest go through a softplus to give positive standard
    deviations.

    Parameters
    ----------
    nodes_in : list of int
        Number of input features of each branch. The default list is only
        read, never mutated.
    node_out : int
        Number of parameters to predict.
    branch_hiddenLayer : int
        Number of hidden layers in each branch.
    trunk_hiddenLayer : int
        Number of hidden layers in the trunk.
    nodes_all : list or None
        Explicit layer widths: one entry per branch followed by one entry
        for the trunk. Built with ``nodeframe.decreasingNode`` when ``None``.
    activation_func : str
        Name of the activation, forwarded to ``seq.LinearSeq``.
    """

    def __init__(self, nodes_in=[100,100,100], node_out=2, branch_hiddenLayer=1,
                 trunk_hiddenLayer=1, nodes_all=None, activation_func='Softplus'):
        super().__init__()
        self.nodes_in = nodes_in
        self.node_out = node_out
        if nodes_all is None:
            nodes_all = []
            branches_out = []
            fc_hidden = branch_hiddenLayer*2 + 1
            # fc_hidden = branch_hiddenLayer + trunk_hiddenLayer + 1 #also works, but not necessary
            fc_out = node_out*2
            for branch_in in nodes_in:
                fc_node = nodeframe.decreasingNode(node_in=branch_in, node_out=fc_out,
                                                   hidden_layer=fc_hidden, get_allNode=True)
                # A branch keeps only the leading part of the full width list.
                nodes_branch = fc_node[:branch_hiddenLayer+2]
                nodes_all.append(nodes_branch)
                branches_out.append(nodes_branch[-1])
            # The trunk input is the concatenation of all branch outputs.
            nodes_all.append(nodeframe.decreasingNode(node_in=sum(branches_out), node_out=fc_out,
                                                      hidden_layer=trunk_hiddenLayer, get_allNode=True))

        self.branch_n = len(nodes_in)
        for i in range(self.branch_n):
            branch = seq.LinearSeq(nodes_all[i], mainActive=activation_func, finalActive=activation_func,
                                   mainBN=True, finalBN=True,
                                   mainDropout='None', finalDropout='None').get_seq()
            # setattr registers the submodule under the same 'branchN' name
            # the former exec() call used, so state_dict keys are unchanged.
            setattr(self, 'branch%s' % (i+1), branch)
        self.trunk = seq.LinearSeq(nodes_all[-1], mainActive=activation_func, finalActive='None',
                                   mainBN=True, finalBN=False,
                                   mainDropout='None', finalDropout='None').get_seq()

    def forward(self, x_all):
        """Return ``(params, sigma)``, each reshaped to ``(-1, node_out)``.

        ``x_all`` is an indexable collection holding one input tensor per
        branch, in branch order.
        """
        # Branches are named from 1, inputs are indexed from 0.
        branch_out = [getattr(self, 'branch%s' % (i+1))(x_all[i]) for i in range(self.branch_n)]
        x_comb = torch.cat(branch_out, 1)
        x = self.trunk(x_comb)
        params = x[:, :self.node_out].view(-1, self.node_out)
        sigma = nn.functional.softplus(x[:, self.node_out:])
        sigma = sigma.view(-1, self.node_out)
        return params, sigma
class MultiBranchMLPMultivariateGaussian(nn.Module):
    """Multi-branch network that outputs a mean vector and a triangular factor.

    Each input data set is processed by its own branch (``self.branch1`` ...
    ``self.branchN``); the branch outputs are concatenated along dim 1 and
    passed through a shared trunk. With automatically built widths the trunk
    ends in ``2*node_out + node_out*(node_out-1)/2`` columns, split into
    means, softplus-transformed diagonal entries and strictly-upper-triangular
    entries of the factor.

    Parameters
    ----------
    nodes_in : list of int
        Number of input features of each branch. The default list is only
        read, never mutated.
    node_out : int
        Number of parameters to predict.
    branch_hiddenLayer : int
        Number of hidden layers in each branch.
    trunk_hiddenLayer : int
        Number of hidden layers in the trunk.
    nodes_all : list or None
        Explicit layer widths: one entry per branch followed by one entry
        for the trunk. Built with ``nodeframe.decreasingNode`` when ``None``.
    activation_func : str
        Name of the activation, forwarded to ``seq.LinearSeq``.
    """

    def __init__(self, nodes_in=[100,100,100], node_out=2, branch_hiddenLayer=1,
                 trunk_hiddenLayer=1, nodes_all=None, activation_func='Softplus'):
        super().__init__()
        self.nodes_in = nodes_in
        self.node_out = node_out
        if nodes_all is None:
            # Width allocation ("method 1"): every branch takes the leading
            # part of its own decreasing width list and the trunk gets a
            # separate list starting at the summed branch outputs. Two
            # alternative schemes that previously sat here as commented-out
            # drafts (splitting one shared list across branches by input-size
            # weight, and summing per-branch tails to form the trunk) were
            # never active and have been dropped.
            nodes_all = []
            branches_out = []
            fc_hidden = branch_hiddenLayer*2 + 1
            # fc_hidden = branch_hiddenLayer + trunk_hiddenLayer + 1 #also works, but not necessary
            fc_out = node_out*2 + (node_out**2 - node_out)//2
            for branch_in in nodes_in:
                fc_node = nodeframe.decreasingNode(node_in=branch_in, node_out=fc_out,
                                                   hidden_layer=fc_hidden, get_allNode=True)
                nodes_branch = fc_node[:branch_hiddenLayer+2]
                nodes_all.append(nodes_branch)
                branches_out.append(nodes_branch[-1])
            # The trunk input is the concatenation of all branch outputs.
            nodes_all.append(nodeframe.decreasingNode(node_in=sum(branches_out), node_out=fc_out,
                                                      hidden_layer=trunk_hiddenLayer, get_allNode=True))

        self.branch_n = len(nodes_in)
        for i in range(self.branch_n):
            branch = seq.LinearSeq(nodes_all[i], mainActive=activation_func, finalActive=activation_func,
                                   mainBN=True, finalBN=True,
                                   mainDropout='None', finalDropout='None').get_seq()
            # setattr registers the submodule under the same 'branchN' name
            # the former exec() call used, so state_dict keys are unchanged.
            setattr(self, 'branch%s' % (i+1), branch)
        self.trunk = seq.LinearSeq(nodes_all[-1], mainActive=activation_func, finalActive='None',
                                   mainBN=True, finalBN=False,
                                   mainDropout='None', finalDropout='None').get_seq()

    def forward(self, x_all):
        """Return ``(params, cholesky_factor)``.

        ``x_all`` is an indexable collection holding one input tensor per
        branch, in branch order. ``params`` has shape
        ``(batch, node_out, 1)`` and ``cholesky_factor`` has shape
        ``(batch, node_out, node_out)`` with zeros below the diagonal.
        """
        # Branches are named from 1, inputs are indexed from 0.
        branch_out = [getattr(self, 'branch%s' % (i+1))(x_all[i]) for i in range(self.branch_n)]
        x_comb = torch.cat(branch_out, 1)
        x = self.trunk(x_comb)
        n = self.node_out
        params = x[:, :n].view(-1, n, 1)
        cholesky_diag = nn.functional.softplus(x[:, n:n*2]).view(-1, n)
        cholesky_factor = torch.diag_embed(cholesky_diag)
        n_offDiag = (n**2 - n)//2
        # With node_out == 1 there is no off-diagonal block; skipping it avoids
        # the ambiguous ``view(-1, 0)`` that raises on an empty tensor.
        if n_offDiag > 0:
            cholesky_offDiag = x[:, n*2:].view(-1, n_offDiag)
            upper_index = torch.triu_indices(n, n, offset=1, device=x.device)
            cholesky_factor[:, upper_index[0], upper_index[1]] = cholesky_offDiag
        return params, cholesky_factor
#%% loss functions
def loss_funcs(params_n):
    """Pick the loss function that matches the number of parameters.

    Returns ``gaussian_loss`` when ``params_n`` equals 1 and
    ``multivariateGaussian_loss`` for any other value.
    """
    return gaussian_loss if params_n == 1 else multivariateGaussian_loss
#%% Branch network
class Branch(nn.Module):
    """Empty placeholder module: it registers no layers and defines no forward."""

    def __init__(self):
        super().__init__()