# Source code for colfi.fcnet_mdn

# -*- coding: utf-8 -*-

from . import sequence as seq
from . import nodeframe
import torch
import torch.nn as nn
from torch.distributions import Categorical
import numpy as np
import warnings


#%%Gaussian Mixture Density Network - for one data set & one parameter
class GaussianMDN(torch.nn.Module):
    """Gaussian mixture density network for one data set.

    Every output parameter gets its own, independent mixture of ``comp_n``
    one-dimensional Gaussians: its own weights ``omega``, means ``mu`` and
    standard deviations ``sigma``.

    Parameters
    ----------
    node_in : int
        Number of input features.
    node_out : int
        Number of output parameters.
    hidden_layer : int
        Passed to ``nodeframe.decreasingNode``; only used when ``nodes`` is None.
    comp_n : int
        Number of mixture components per parameter.
    nodes : list of int or None
        Explicit node layout for ``seq.LinearSeq``. Its last entry should be
        ``node_out*comp_n*3`` so that ``forward`` can split the output.
    activation_func : str
        Passed to ``seq.LinearSeq`` as ``mainActive``.
    """
    def __init__(self, node_in=100, node_out=1, hidden_layer=3, comp_n=3, nodes=None, activation_func='Softplus'):
        super(GaussianMDN, self).__init__()
        self.node_in = node_in
        self.node_out = node_out
        self.comp_n = comp_n
        if nodes is None:
            # each parameter has independent omega, mu and sigma -> 3 outputs per (parameter, component)
            nodes = nodeframe.decreasingNode(node_in=node_in, node_out=node_out*comp_n*3, hidden_layer=hidden_layer, get_allNode=True)
        self.fcnet = seq.LinearSeq(nodes, mainActive=activation_func, finalActive='None', mainBN=True,
                                   finalBN=False, mainDropout='None', finalDropout='None').get_seq()

    def forward(self, x):
        """Map inputs to mixture parameters.

        Returns
        -------
        omega, mu, sigma : torch.Tensor
            Each of shape (batch, node_out, comp_n). ``omega`` sums to 1 over
            the last dimension for every parameter; ``sigma`` is positive.
        """
        x = self.fcnet(x)
        n = self.node_out*self.comp_n
        # Reshape *before* the softmax so the weights are normalized over the
        # components of each parameter (dim=2), not jointly over all
        # node_out*comp_n entries (which would not sum to 1 per parameter).
        omega = nn.Softmax(dim=2)(x[:, :n].view(-1, self.node_out, self.comp_n))
        mu = x[:, n:2*n].view(-1, self.node_out, self.comp_n)
        sigma = nn.Softplus()(x[:, 2*n:]).view(-1, self.node_out, self.comp_n)
        return omega, mu, sigma
def gaussian_PDF(mu, sigma, target, log=True):
    """Evaluate the normal density N(target | mu, sigma), elementwise.

    See https://en.wikipedia.org/wiki/Normal_distribution

    Parameters
    ----------
    mu, sigma, target : torch.Tensor
        Mean, standard deviation and evaluation point (broadcast together).
    log : bool
        If True return the log-density, otherwise the density itself.

    Returns
    -------
    torch.Tensor
        The elementwise (log-)density.
    """
    norm_const = torch.sqrt(torch.tensor(2*np.pi))  # sqrt(2*pi)
    z = (target - mu) / sigma
    if not log:
        return torch.exp(-0.5*z**2) / sigma / norm_const
    return -0.5*z**2 - torch.log(sigma) - torch.log(norm_const)
def gaussian_loss(omega, mu, sigma, target, logsumexp=True):
    """Negative log-likelihood of ``target`` under per-parameter Gaussian mixtures.

    ``torch.logsumexp`` will help us to avoid a lot of numerical instabilities
    in the training process, see:
    https://deep-and-shallow.com/2021/03/20/mixture-density-networks-probabilistic-regression-for-uncertainty-estimation/

    Parameters
    ----------
    omega, mu, sigma : torch.Tensor
        Mixture weights, means and standard deviations with the component index
        on dim 2 (the layout returned by ``GaussianMDN.forward``).
    target : torch.Tensor
        Targets with one fewer dimension; they are broadcast over the components.
    logsumexp : bool
        If True (recommended) work in log space throughout.

    Returns
    -------
    torch.Tensor
        Scalar: batch mean of the negative log-likelihood summed over parameters.
    """
    target = target.unsqueeze(2).expand_as(sigma)
    if logsumexp:
        log_prob = gaussian_PDF(mu, sigma, target, log=True)
        log_omega = torch.log(omega)
        prob = torch.logsumexp(log_omega+log_prob, dim=2) #sum over the comp_n dimension
        prob = -torch.sum(prob, dim=1) #sum ln(PDF) of all parameters
    else:
        prob = omega * gaussian_PDF(mu, sigma, target, log=False)
        prob = torch.sum(prob, dim=2) #sum over the comp_n dimension
        # -ln(prod_i p_i) == -sum_i ln(p_i); summing logs keeps the product of
        # many small densities from underflowing to 0 (which would give inf)
        prob = -torch.sum(torch.log(prob), dim=1)
    return torch.mean(prob)
def gaussian_sampler(omega, mu, sigma, chain_leng=10000):
    """Draw samples from a per-parameter mixture of Gaussians (MoG).

    Parameters
    ----------
    omega, mu, sigma : torch.Tensor
        Mixture weights, means and standard deviations with the component index
        on dim 2. Dim 0 is expanded to ``chain_leng``, so it must have size 1
        (or already equal ``chain_leng``).
    chain_leng : int
        Number of samples to draw.

    Returns
    -------
    torch.Tensor
        Samples with every size-1 dimension squeezed away.
    """
    def _tile(t):
        # repeat the single batch entry chain_leng times (expand is a view)
        return t.expand([chain_leng] + list(t.size())[1:])

    weights = _tile(omega)
    means = _tile(mu)
    stds = _tile(sigma)
    # choose one component for every sample and every parameter
    picked = Categorical(weights).sample()
    picked = picked.view(weights.size(0), weights.size(1), 1)
    chosen_mu = means.detach().gather(2, picked).squeeze()
    chosen_sigma = stds.gather(2, picked).detach().squeeze()
    return torch.distributions.normal.Normal(chosen_mu, chosen_sigma).sample()
#%%Multivariate Gaussian Mixture Density Network - for one data set & multiple parameters
class MultivariateGaussianMDN(torch.nn.Module):
    """Multivariate Gaussian mixture density network for one data set.

    The network predicts ``comp_n`` mixture weights shared by all parameters,
    and for every component a mean vector and the upper-triangular Cholesky
    factor of the inverse covariance matrix.

    Parameters
    ----------
    node_in : int
        Number of input features.
    node_out : int
        Number of output parameters (dimension of each Gaussian).
    hidden_layer : int
        Passed to ``nodeframe.decreasingNode``; only used when ``nodes`` is None.
    comp_n : int
        Number of mixture components.
    nodes : list of int or None
        Explicit node layout for ``seq.LinearSeq``.
    activation_func : str
        Passed to ``seq.LinearSeq`` as ``mainActive``.
    """
    def __init__(self, node_in=100, node_out=2, hidden_layer=3, comp_n=3, nodes=None, activation_func='Softplus'):
        super(MultivariateGaussianMDN, self).__init__()
        self.node_in = node_in
        self.node_out = node_out
        self.comp_n = comp_n
        if nodes is None:
            # weights + means + Cholesky diagonals + Cholesky off-diagonals
            out_size = comp_n + node_out*comp_n*2 + comp_n*(node_out**2 - node_out)//2
            nodes = nodeframe.decreasingNode(node_in=node_in, node_out=out_size, hidden_layer=hidden_layer, get_allNode=True)
        self.fcnet = seq.LinearSeq(nodes, mainActive=activation_func, finalActive='None', mainBN=True,
                                   finalBN=False, mainDropout='None', finalDropout='None').get_seq()

    def forward(self, x):
        """Map inputs to (omega, mu, cholesky_factor).

        Returns
        -------
        omega : (batch, comp_n) softmax-normalized weights
        mu : (batch, comp_n, node_out, 1)
        cholesky_factor : (batch, comp_n, node_out, node_out), upper triangular
            with Softplus (positive) diagonal
        """
        out = self.fcnet(x)
        k, p = self.comp_n, self.node_out
        # output layout: [weights (k) | means (k*p) | diagonals (k*p) | off-diagonals]
        diag_start = k + p*k
        offdiag_start = k + 2*p*k
        omega = nn.Softmax(dim=1)(out[:, :k])
        mu = out[:, k:diag_start].view(-1, k, p, 1)
        diag = nn.Softplus()(out[:, diag_start:offdiag_start]).view(-1, k, p)
        cholesky_factor = torch.diag_embed(diag)
        rows, cols = torch.triu_indices(p, p, offset=1)
        cholesky_factor[:, :, rows, cols] = out[:, offdiag_start:].view(-1, k, (p**2 - p)//2)
        return omega, mu, cholesky_factor
#need further research to ensure the corresponding equations are correct
class MultivariateGaussianMDN_AvgMultiNoise(torch.nn.Module):
    """Multivariate Gaussian MDN whose outputs are averaged over several noise realizations.

    The difference between this class and MultivariateGaussianMDN is the
    forward function, where outputs caused by multiple noises are averaged.
    The construction of the network itself is the same.
    """
    def __init__(self, node_in=100, node_out=2, hidden_layer=3, comp_n=3, nodes=None, activation_func='Softplus'):
        super(MultivariateGaussianMDN_AvgMultiNoise, self).__init__()
        self.node_in = node_in
        self.node_out = node_out
        self.comp_n = comp_n
        if nodes is None:
            # weights + means + Cholesky diagonals + Cholesky off-diagonals
            nodes = nodeframe.decreasingNode(node_in=node_in, node_out=comp_n+node_out*comp_n*2+comp_n*(node_out**2-node_out)//2, hidden_layer=hidden_layer, get_allNode=True)
        self.fcnet = seq.LinearSeq(nodes, mainActive=activation_func, finalActive='None', mainBN=True,
                                   finalBN=False, mainDropout='None', finalDropout='None').get_seq() #need further research

    def forward(self, x, multi_noise=1):
        """Map inputs to (omega, mu, cholesky_factor), averaging over noise realizations.

        Parameters
        ----------
        x : torch.Tensor
            Input batch. For ``multi_noise > 1`` it is split into
            ``multi_noise`` equal consecutive chunks along dim 0 (presumably
            the same samples with different noise — confirm against caller).
        multi_noise : int
            Number of chunks to average; 1 (the default) is used for prediction.

        Returns
        -------
        omega : (batch//multi_noise, comp_n) softmax of the averaged logits
        mu : (batch//multi_noise, comp_n, node_out, 1) averaged means
        cholesky_factor : (batch//multi_noise, comp_n, node_out, node_out)
            sum of the chunks' Cholesky factors divided by sqrt(multi_noise)

        Raises
        ------
        ValueError
            If the batch size is not divisible by ``multi_noise``.
        """
        x = self.fcnet(x)
        k, p = self.comp_n, self.node_out
        diag_start = k + p*k
        offdiag_start = k + 2*p*k
        # omega stays as logits here; the softmax is applied after averaging
        omega = x[:, :k]
        mu = x[:, k:diag_start].view(-1, k, p, 1)
        cholesky_diag = nn.Softplus()(x[:, diag_start:offdiag_start]).view(-1, k, p)
        cholesky_factor = torch.diag_embed(cholesky_diag)
        upper_index = torch.triu_indices(p, p, offset=1)
        cholesky_factor[:, :, upper_index[0], upper_index[1]] = x[:, offdiag_start:].view(-1, k, (p**2-p)//2)
        if multi_noise > 1:
            batch = x.size(0)
            if batch % multi_noise != 0:
                raise ValueError('The batch size (%s) must be divisible by multi_noise (%s)'%(batch, multi_noise))
            # splitting dim 0 into (multi_noise, batch//multi_noise) matches torch.chunk(..., multi_noise, dim=0)
            omega = omega.reshape(multi_noise, -1, k).sum(dim=0) / multi_noise
            mu = mu.reshape(multi_noise, -1, k, p, 1).sum(dim=0) / multi_noise
            # The Cholesky factors are summed and scaled by 1/sqrt(multi_noise), not averaged.
            # An earlier variant summed the precision matrices L^T L over the chunks and
            # re-factorised them with torch.cholesky(..., upper=True); it is not used.
            cholesky_factor = cholesky_factor.reshape(multi_noise, -1, k, p, p).sum(dim=0) / torch.sqrt(torch.tensor(multi_noise))
        omega = nn.Softmax(dim=1)(omega)
        return omega, mu, cholesky_factor
#test, not good, remove?
class MultivariateGaussianMDN_AvgMu(torch.nn.Module):
    """Multivariate Gaussian MDN that averages only the weights and means over noise realizations.

    The difference between this class and MultivariateGaussianMDN_AvgMultiNoise
    is that only the mu and w (omega) of the mixture model are averaged; the
    Cholesky factors are returned per input row.
    """
    def __init__(self, node_in=100, node_out=2, hidden_layer=3, comp_n=3, nodes=None, activation_func='Softplus'):
        super(MultivariateGaussianMDN_AvgMu, self).__init__()
        self.node_in = node_in
        self.node_out = node_out
        self.comp_n = comp_n
        if nodes is None:
            # weights + means + Cholesky diagonals + Cholesky off-diagonals
            nodes = nodeframe.decreasingNode(node_in=node_in, node_out=comp_n+node_out*comp_n*2+comp_n*(node_out**2-node_out)//2, hidden_layer=hidden_layer, get_allNode=True)
        self.fcnet = seq.LinearSeq(nodes, mainActive=activation_func, finalActive='None', mainBN=True,
                                   finalBN=False, mainDropout='None', finalDropout='None').get_seq() #need further research

    def forward(self, x, multi_noise=1):
        """Map inputs to (omega, mu, cholesky_factor), averaging omega and mu over noise realizations.

        Parameters
        ----------
        x : torch.Tensor
            Input batch; for ``multi_noise > 1`` it is split into
            ``multi_noise`` equal consecutive chunks along dim 0.
        multi_noise : int
            Number of chunks to average; 1 (the default) is used for prediction.

        Returns
        -------
        omega : (batch, comp_n) softmax of the chunk-averaged logits, tiled back to the full batch
        mu : (batch, comp_n, node_out, 1) chunk-averaged means, tiled back to the full batch
        cholesky_factor : (batch, comp_n, node_out, node_out), not averaged

        Raises
        ------
        ValueError
            If the batch size is not divisible by ``multi_noise``.
        """
        x = self.fcnet(x)
        k, p = self.comp_n, self.node_out
        diag_start = k + p*k
        offdiag_start = k + 2*p*k
        # omega stays as logits here; the softmax is applied after averaging
        omega = x[:, :k]
        mu = x[:, k:diag_start].view(-1, k, p, 1)
        cholesky_diag = nn.Softplus()(x[:, diag_start:offdiag_start]).view(-1, k, p)
        cholesky_factor = torch.diag_embed(cholesky_diag)
        upper_index = torch.triu_indices(p, p, offset=1)
        cholesky_factor[:, :, upper_index[0], upper_index[1]] = x[:, offdiag_start:].view(-1, k, (p**2-p)//2)
        if multi_noise > 1:
            batch = x.size(0)
            if batch % multi_noise != 0:
                raise ValueError('The batch size (%s) must be divisible by multi_noise (%s)'%(batch, multi_noise))
            # average over the chunks, then tile back so every row pairs the averaged
            # omega/mu with its own (unaveraged) Cholesky factor
            omega = omega.reshape(multi_noise, -1, k).sum(dim=0) / multi_noise
            omega = omega.repeat(multi_noise, 1)
            mu = mu.reshape(multi_noise, -1, k, p, 1).sum(dim=0) / multi_noise
            mu = mu.repeat(multi_noise, 1, 1, 1)
        omega = nn.Softmax(dim=1)(omega)
        return omega, mu, cholesky_factor
def multivariateGaussian_PDF(mu, cholesky_factor, target, log=True):
    """Evaluate a multivariate normal (log-)density given the Cholesky factor of its precision.

    Here ``cholesky_factor`` is U with inverse covariance cov^{-1} = U^T U
    (see arXiv:2003.05739). Because U is triangular, ln det(U) is the sum of
    the logs of its diagonal, so

        ln N = -0.5 |U (x - mu)|^2 + sum_i ln U_ii - 0.5 k ln(2 pi).

    Parameters
    ----------
    mu, target : torch.Tensor
        Column vectors of shape (batch, comp_n, k, 1).
    cholesky_factor : torch.Tensor
        Triangular factors of shape (batch, comp_n, k, k) with positive diagonal.
    log : bool
        If True return the log-density, otherwise the density.

    Returns
    -------
    torch.Tensor
        Shape (batch, comp_n).
    """
    k = cholesky_factor.size(-1)
    norm_const = torch.sqrt(torch.tensor(2*np.pi)**k)  # sqrt((2*pi)^k)
    whitened = torch.matmul(cholesky_factor, target - mu)
    # squared Mahalanobis distance; [:, :, 0, 0] drops the trailing 1x1 matrix dims
    maha = torch.matmul(whitened.transpose(2, 3), whitened)[:, :, 0, 0]
    diag = torch.diagonal(cholesky_factor, dim1=2, dim2=3)
    if log:
        return -0.5*maha + torch.sum(torch.log(diag), dim=2) - torch.log(norm_const)
    return torch.exp(-0.5*maha) * torch.prod(diag, dim=2) / norm_const
def multivariateGaussian_loss(omega, mu, cholesky_factor, target, logsumexp=True):
    """Negative log-likelihood of ``target`` under a mixture of multivariate Gaussians.

    The loss is the negative log likelihood of the data given the MoG
    parameters, averaged over the batch. ``torch.logsumexp`` will help us to
    avoid a lot of numerical instabilities in training, see:
    https://deep-and-shallow.com/2021/03/20/mixture-density-networks-probabilistic-regression-for-uncertainty-estimation/

    Parameters
    ----------
    omega : torch.Tensor
        Mixture weights, component index on dim 1.
    mu, cholesky_factor : torch.Tensor
        Means (batch, comp_n, k, 1) and precision Cholesky factors, as
        consumed by ``multivariateGaussian_PDF``.
    target : torch.Tensor
        Targets of shape (batch, k); broadcast over the components.
    logsumexp : bool
        If True (recommended) combine the components in log space.
    """
    target = target.unsqueeze(1).unsqueeze(-1).expand_as(mu)
    if not logsumexp:
        mixture = torch.sum(omega * multivariateGaussian_PDF(mu, cholesky_factor, target, log=logsumexp), dim=1)
        return torch.mean(-torch.log(mixture))
    # log-sum over the comp_n dimension (dim=1)
    weighted = torch.log(omega) + multivariateGaussian_PDF(mu, cholesky_factor, target, log=logsumexp)
    return torch.mean(-torch.logsumexp(weighted, dim=1))
def multivariateGaussian_sampler(omega, mu, cholesky_factor, chain_leng=10000):
    """Draw samples from a mixture of multivariate Gaussians.

    Parameters
    ----------
    omega : torch.Tensor
        Mixture weights with the component index on dim 1; dim 0 is expanded
        to ``chain_leng``, so it must have size 1 (or equal ``chain_leng``).
    mu : torch.Tensor
        Means of shape (1, comp_n, k, 1).
    cholesky_factor : torch.Tensor
        Factor U of the inverse covariance, cov^{-1} = U^T U, shape (1, comp_n, k, k).
    chain_leng : int
        Number of samples to draw.

    Returns
    -------
    torch.Tensor or None
        Samples with size-1 dimensions squeezed (e.g. (chain_leng, k)), or None
        (after a warning) if the covariance cannot be Cholesky-decomposed.
    """
    cov_inv = torch.matmul(cholesky_factor.transpose(2,3), cholesky_factor)
    cov_true = torch.inverse(cov_inv)
    try:
        # torch.cholesky is deprecated; torch.linalg.cholesky returns the lower factor
        # (cov = L L^T) and raises a RuntimeError subclass for non positive-definite input
        L = torch.linalg.cholesky(cov_true)
    except RuntimeError:
        warnings.warn('It is failed when using torch.linalg.cholesky and no ANN chain is obtained, because the covariance matrix is not positive definite!')
        return None
    omega = omega.expand([chain_leng] + list(omega.size())[1:])
    mu = mu.expand([chain_leng] + list(mu.size())[1:])
    L = L.expand([chain_leng] + list(L.size())[1:])
    # one component index per sample, broadcast to gather whole mean vectors / factor matrices
    omegas = Categorical(omega).sample().view(chain_leng, 1, 1, 1)
    mus = mu.detach().gather(1, omegas.expand(chain_leng, 1, mu.size(2), mu.size(3))).squeeze()
    Ls = L.detach().gather(1, omegas.expand(chain_leng, 1, L.size(2), L.size(3))).squeeze()
    samples = torch.distributions.multivariate_normal.MultivariateNormal(mus, scale_tril=Ls).sample() #use scale_tril is more efficient
    return samples
#%% (Multivariate) Gaussian Mixture Density Network - for multiple data sets & one (multiple) parameter
class MultiBranchGaussianMDN(nn.Module):
    """Gaussian MDN with one input branch per data set and a shared trunk.

    Data set ``i`` is processed by its own sub-network ``self.branch{i+1}``.
    The branch outputs are concatenated and passed through ``self.trunk``,
    whose output holds ``comp_n`` mixture weights (shared by all parameters)
    followed by ``comp_n*node_out`` means and ``comp_n*node_out`` standard
    deviations.

    Parameters
    ----------
    nodes_in : list of int
        Input size of each branch, one entry per data set (read only).
    node_out : int
        Number of output parameters.
    branch_hiddenLayer, trunk_hiddenLayer : int
        Used to build the node layouts with ``nodeframe.decreasingNode`` when
        ``nodes_all`` is None.
    comp_n : int
        Number of mixture components.
    nodes_all : list of list of int or None
        Explicit node layouts: one per branch followed by one for the trunk.
    activation_func : str
        Passed to ``seq.LinearSeq`` as the activation.
    """
    def __init__(self, nodes_in=[100,100,100], node_out=2, branch_hiddenLayer=1, trunk_hiddenLayer=1, comp_n=3, nodes_all=None, activation_func='Softplus'):
        super(MultiBranchGaussianMDN, self).__init__()
        self.nodes_in = nodes_in
        self.node_out = node_out
        self.comp_n = comp_n
        if nodes_all is None:
            nodes_all = []
            branches_out = []
            fc_hidden = branch_hiddenLayer*2 + 1
            # fc_hidden = branch_hiddenLayer + trunk_hiddenLayer + 1 #also works, but not necessary
            fc_out = comp_n+node_out*comp_n*2
            for node_in in nodes_in:
                fc_node = nodeframe.decreasingNode(node_in=node_in, node_out=fc_out, hidden_layer=fc_hidden, get_allNode=True)
                # each branch uses the first branch_hiddenLayer+2 entries of a longer decreasing layout
                nodes_branch = fc_node[:branch_hiddenLayer+2]
                nodes_all.append(nodes_branch)
                branches_out.append(nodes_branch[-1])
            nodes_all.append(nodeframe.decreasingNode(node_in=sum(branches_out), node_out=fc_out, hidden_layer=trunk_hiddenLayer, get_allNode=True))
        self.branch_n = len(nodes_in)
        for i in range(self.branch_n):
            # setattr (instead of exec) registers the sub-module under the same name: branch1, branch2, ...
            setattr(self, 'branch%s'%(i+1), seq.LinearSeq(nodes_all[i], mainActive=activation_func, finalActive=activation_func, mainBN=True,
                                                          finalBN=True, mainDropout='None', finalDropout='None').get_seq())
        self.trunk = seq.LinearSeq(nodes_all[self.branch_n], mainActive=activation_func, finalActive='None', mainBN=True,
                                   finalBN=False, mainDropout='None', finalDropout='None').get_seq()

    def forward(self, x_all):
        """Map one input tensor per data set to (omega, mu, sigma).

        Parameters
        ----------
        x_all : sequence of torch.Tensor
            Inputs in branch order; ``x_all[i]`` is fed to ``self.branch{i+1}``.

        Returns
        -------
        omega : (batch, comp_n)
        mu, sigma : (batch, comp_n, node_out); sigma is positive
            NOTE(review): this layout differs from GaussianMDN's
            (batch, node_out, comp_n) — confirm the consumer expects it.
        """
        branch_outs = [getattr(self, 'branch%s'%(i+1))(x_all[i]) for i in range(self.branch_n)]
        x = self.trunk(torch.cat(branch_outs, 1))
        n = self.node_out*self.comp_n
        omega = nn.Softmax(dim=1)(x[:, :self.comp_n])
        mu = x[:, self.comp_n:self.comp_n+n].view(-1, self.comp_n, self.node_out)
        sigma = nn.Softplus()(x[:, self.comp_n+n:]).view(-1, self.comp_n, self.node_out)
        return omega, mu, sigma
class MultiBranchMultivariateGaussianMDN(nn.Module):
    """Multivariate Gaussian MDN with one input branch per data set and a shared trunk.

    Data set ``i`` is processed by its own sub-network ``self.branch{i+1}``.
    The concatenated branch outputs go through ``self.trunk``, whose output
    is split into ``comp_n`` mixture weights, ``comp_n*node_out`` means,
    ``comp_n*node_out`` Cholesky diagonals and
    ``comp_n*node_out*(node_out-1)/2`` Cholesky off-diagonals.

    Parameters
    ----------
    nodes_in : list of int
        Input size of each branch, one entry per data set (read only).
    node_out : int
        Number of output parameters (dimension of each Gaussian).
    branch_hiddenLayer, trunk_hiddenLayer : int
        Used to build the node layouts with ``nodeframe.decreasingNode`` when
        ``nodes_all`` is None.
    comp_n : int
        Number of mixture components.
    nodes_all : list of list of int or None
        Explicit node layouts: one per branch followed by one for the trunk.
    activation_func : str
        Passed to ``seq.LinearSeq`` as the activation.
    """
    def __init__(self, nodes_in=[100,100,100], node_out=2, branch_hiddenLayer=1, trunk_hiddenLayer=1, comp_n=3, nodes_all=None, activation_func='Softplus'):
        super(MultiBranchMultivariateGaussianMDN, self).__init__()
        self.nodes_in = nodes_in
        self.node_out = node_out
        self.comp_n = comp_n
        if nodes_all is None:
            nodes_all = []
            branches_out = []
            fc_hidden = branch_hiddenLayer*2 + 1
            # fc_hidden = branch_hiddenLayer + trunk_hiddenLayer + 1 #also works, but not necessary
            fc_out = comp_n+node_out*comp_n*2+comp_n*(node_out**2-node_out)//2
            for node_in in nodes_in:
                fc_node = nodeframe.decreasingNode(node_in=node_in, node_out=fc_out, hidden_layer=fc_hidden, get_allNode=True)
                # each branch uses the first branch_hiddenLayer+2 entries of a longer decreasing layout
                nodes_branch = fc_node[:branch_hiddenLayer+2]
                nodes_all.append(nodes_branch)
                branches_out.append(nodes_branch[-1])
            nodes_all.append(nodeframe.decreasingNode(node_in=sum(branches_out), node_out=fc_out, hidden_layer=trunk_hiddenLayer, get_allNode=True))
        self.branch_n = len(nodes_in)
        for i in range(self.branch_n):
            # setattr (instead of exec) registers the sub-module under the same name: branch1, branch2, ...
            setattr(self, 'branch%s'%(i+1), seq.LinearSeq(nodes_all[i], mainActive=activation_func, finalActive=activation_func, mainBN=True,
                                                          finalBN=True, mainDropout='None', finalDropout='None').get_seq())
        self.trunk = seq.LinearSeq(nodes_all[self.branch_n], mainActive=activation_func, finalActive='None', mainBN=True,
                                   finalBN=False, mainDropout='None', finalDropout='None').get_seq()

    def forward(self, x_all):
        """Map one input tensor per data set to (omega, mu, cholesky_factor).

        Parameters
        ----------
        x_all : sequence of torch.Tensor
            Inputs in branch order; ``x_all[i]`` is fed to ``self.branch{i+1}``.

        Returns
        -------
        omega : (batch, comp_n) softmax-normalized weights
        mu : (batch, comp_n, node_out, 1)
        cholesky_factor : (batch, comp_n, node_out, node_out), upper triangular
            with Softplus (positive) diagonal
        """
        branch_outs = [getattr(self, 'branch%s'%(i+1))(x_all[i]) for i in range(self.branch_n)]
        x = self.trunk(torch.cat(branch_outs, 1))
        k, p = self.comp_n, self.node_out
        diag_start = k + p*k
        offdiag_start = k + 2*p*k
        omega = nn.Softmax(dim=1)(x[:, :k])
        mu = x[:, k:diag_start].view(-1, k, p, 1)
        cholesky_diag = nn.Softplus()(x[:, diag_start:offdiag_start]).view(-1, k, p)
        cholesky_factor = torch.diag_embed(cholesky_diag)
        upper_index = torch.triu_indices(p, p, offset=1)
        cholesky_factor[:, :, upper_index[0], upper_index[1]] = x[:, offdiag_start:].view(-1, k, (p**2-p)//2)
        return omega, mu, cholesky_factor
#%%Beta Mixture Density Network - for one data set & one parameter
class BetaMDN(torch.nn.Module):
    """Beta mixture density network for one data set.

    Every output parameter gets its own, independent mixture of ``comp_n``
    Beta distributions: its own weights ``omega`` and shape parameters
    ``alpha`` and ``beta``.

    Parameters
    ----------
    node_in : int
        Number of input features.
    node_out : int
        Number of output parameters.
    hidden_layer : int
        Passed to ``nodeframe.decreasingNode``; only used when ``nodes`` is None.
    comp_n : int
        Number of mixture components per parameter.
    nodes : list of int or None
        Explicit node layout for ``seq.LinearSeq``. Its last entry should be
        ``node_out*comp_n*3`` so that ``forward`` can split the output.
    activation_func : str
        Passed to ``seq.LinearSeq`` as ``mainActive``.
    """
    def __init__(self, node_in=100, node_out=1, hidden_layer=3, comp_n=3, nodes=None, activation_func='Softplus'):
        super(BetaMDN, self).__init__()
        self.node_in = node_in
        self.node_out = node_out
        self.comp_n = comp_n
        if nodes is None:
            # each parameter has independent omega, alpha and beta -> 3 outputs per (parameter, component)
            nodes = nodeframe.decreasingNode(node_in=node_in, node_out=node_out*comp_n*3, hidden_layer=hidden_layer, get_allNode=True)
        self.fcnet = seq.LinearSeq(nodes, mainActive=activation_func, finalActive='None', mainBN=True,
                                   finalBN=False, mainDropout='None', finalDropout='None').get_seq()

    def forward(self, x):
        """Map inputs to mixture parameters.

        Returns
        -------
        omega, alpha, beta : torch.Tensor
            Each of shape (batch, node_out, comp_n). ``omega`` sums to 1 over
            the last dimension for every parameter; ``alpha`` and ``beta``
            are positive (Softplus).
        """
        x = self.fcnet(x)
        n = self.node_out*self.comp_n
        # Reshape *before* the softmax so the weights are normalized per parameter
        # (dim=2), not jointly over all node_out*comp_n entries.
        omega = nn.Softmax(dim=2)(x[:, :n].view(-1, self.node_out, self.comp_n))
        alpha = nn.Softplus()(x[:, n:2*n]).view(-1, self.node_out, self.comp_n)
        beta = nn.Softplus()(x[:, 2*n:]).view(-1, self.node_out, self.comp_n)
        return omega, alpha, beta
def beta_PDF(alpha, beta, target, log=True):
    """Evaluate the Beta-distribution (log-)density, elementwise.

    See https://en.wikipedia.org/wiki/Beta_distribution

        p(x) = x^(alpha-1) (1-x)^(beta-1) Gamma(alpha+beta) / (Gamma(alpha) Gamma(beta))

    Note: the density requires 0 < target < 1 (otherwise ``torch.log`` of
    ``target`` or ``1-target`` is not finite), so targets should be min-max
    normalized into that open interval.

    Parameters
    ----------
    alpha, beta, target : torch.Tensor
        Shape parameters and evaluation point (broadcast together).
    log : bool
        If True return the log-density, otherwise the density itself.
    """
    # torch.lgamma returns ln|Gamma(x)|, which equals ln Gamma(x) only for x > 0;
    # alpha and beta are expected to be positive (BetaMDN produces them with Softplus)
    log_prob = (alpha-1)*torch.log(target) + (beta-1)*torch.log(1-target) \
        + torch.lgamma(alpha+beta) - torch.lgamma(alpha) - torch.lgamma(beta)
    return log_prob if log else torch.exp(log_prob)
def beta_loss(omega, alpha, beta, target, logsumexp=True):
    """Negative log-likelihood of ``target`` under per-parameter Beta mixtures.

    Parameters
    ----------
    omega, alpha, beta : torch.Tensor
        Mixture weights and Beta shape parameters with the component index
        on dim 2 (the layout returned by ``BetaMDN.forward``).
    target : torch.Tensor
        Targets in (0, 1) with one fewer dimension; broadcast over the components.
    logsumexp : bool
        If True (recommended) combine the components in log space.

    Returns
    -------
    torch.Tensor
        Scalar: batch mean of the negative log-likelihood summed over parameters.
    """
    target = target.unsqueeze(2).expand_as(alpha)
    if logsumexp:
        log_prob = beta_PDF(alpha, beta, target, log=True)
        log_omega = torch.log(omega)
        prob = torch.logsumexp(log_omega+log_prob, dim=2) #sum over the comp_n dimension
        prob = -torch.sum(prob, dim=1) #sum ln(PDF) of all parameters
    else:
        prob = omega * beta_PDF(alpha, beta, target, log=False)
        prob = torch.sum(prob, dim=2) #sum over the comp_n dimension
        # -ln(prod_i p_i) == -sum_i ln(p_i); summing logs keeps the product of
        # many small densities from underflowing to 0 (which would give inf)
        prob = -torch.sum(torch.log(prob), dim=1)
    return torch.mean(prob)
def beta_sampler(omega, alpha, beta, chain_leng=10000):
    """Draw samples from a per-parameter mixture of Beta distributions.

    Parameters
    ----------
    omega, alpha, beta : torch.Tensor
        Mixture weights and Beta shape parameters with the component index on
        dim 2. Dim 0 is expanded to ``chain_leng``, so it must have size 1
        (or already equal ``chain_leng``).
    chain_leng : int
        Number of samples to draw.

    Returns
    -------
    torch.Tensor
        Samples with every size-1 dimension squeezed away.
    """
    weights, a, b = [t.expand([chain_leng] + list(t.size())[1:]) for t in (omega, alpha, beta)]
    # choose one component for every sample and every parameter
    idx = Categorical(weights).sample().view(weights.size(0), weights.size(1), 1)
    a_sel = a.detach().gather(2, idx).squeeze()
    b_sel = b.gather(2, idx).detach().squeeze()
    return torch.distributions.beta.Beta(a_sel, b_sel).sample()
#%%Beta Mixture Density Network - for multiple data sets & one parameter
class MultiBranchBetaMDN(nn.Module):
    """Beta mixture density network with one input branch per data set and a shared trunk.

    Data set ``i`` is processed by its own sub-network ``self.branch{i+1}``.
    The concatenated branch outputs go through ``self.trunk``, whose output
    holds, for every parameter, ``comp_n`` weights, ``comp_n`` alphas and
    ``comp_n`` betas.

    Parameters
    ----------
    nodes_in : list of int
        Input size of each branch, one entry per data set (read only).
    node_out : int
        Number of output parameters.
    branch_hiddenLayer, trunk_hiddenLayer : int
        Used to build the node layouts with ``nodeframe.decreasingNode`` when
        ``nodes_all`` is None.
    comp_n : int
        Number of mixture components per parameter.
    nodes_all : list of list of int or None
        Explicit node layouts: one per branch followed by one for the trunk.
    activation_func : str
        Passed to ``seq.LinearSeq`` as the activation.
    """
    def __init__(self, nodes_in=[100,100,100], node_out=2, branch_hiddenLayer=1, trunk_hiddenLayer=1, comp_n=3, nodes_all=None, activation_func='Softplus'):
        super(MultiBranchBetaMDN, self).__init__()
        self.nodes_in = nodes_in
        self.node_out = node_out
        self.comp_n = comp_n
        if nodes_all is None:
            nodes_all = []
            branches_out = []
            fc_hidden = branch_hiddenLayer*2 + 1
            # fc_hidden = branch_hiddenLayer + trunk_hiddenLayer + 1 #also works, but not necessary
            fc_out = node_out*comp_n*3
            for node_in in nodes_in:
                fc_node = nodeframe.decreasingNode(node_in=node_in, node_out=fc_out, hidden_layer=fc_hidden, get_allNode=True)
                # each branch uses the first branch_hiddenLayer+2 entries of a longer decreasing layout
                nodes_branch = fc_node[:branch_hiddenLayer+2]
                nodes_all.append(nodes_branch)
                branches_out.append(nodes_branch[-1])
            nodes_all.append(nodeframe.decreasingNode(node_in=sum(branches_out), node_out=fc_out, hidden_layer=trunk_hiddenLayer, get_allNode=True))
        self.branch_n = len(nodes_in)
        for i in range(self.branch_n):
            # setattr (instead of exec) registers the sub-module under the same name: branch1, branch2, ...
            setattr(self, 'branch%s'%(i+1), seq.LinearSeq(nodes_all[i], mainActive=activation_func, finalActive=activation_func, mainBN=True,
                                                          finalBN=True, mainDropout='None', finalDropout='None').get_seq())
        self.trunk = seq.LinearSeq(nodes_all[self.branch_n], mainActive=activation_func, finalActive='None', mainBN=True,
                                   finalBN=False, mainDropout='None', finalDropout='None').get_seq()

    def forward(self, x_all):
        """Map one input tensor per data set to (omega, alpha, beta).

        Parameters
        ----------
        x_all : sequence of torch.Tensor
            Inputs in branch order; ``x_all[i]`` is fed to ``self.branch{i+1}``.

        Returns
        -------
        omega, alpha, beta : torch.Tensor
            Each of shape (batch, node_out, comp_n). ``omega`` sums to 1 over
            the last dimension for every parameter; ``alpha`` and ``beta``
            are positive (Softplus).
        """
        branch_outs = [getattr(self, 'branch%s'%(i+1))(x_all[i]) for i in range(self.branch_n)]
        x = self.trunk(torch.cat(branch_outs, 1))
        n = self.node_out*self.comp_n
        # Reshape *before* the softmax so the weights are normalized per parameter
        # (dim=2), not jointly over all node_out*comp_n entries.
        omega = nn.Softmax(dim=2)(x[:, :n].view(-1, self.node_out, self.comp_n))
        alpha = nn.Softplus()(x[:, n:2*n]).view(-1, self.node_out, self.comp_n)
        beta = nn.Softplus()(x[:, 2*n:]).view(-1, self.node_out, self.comp_n)
        return omega, alpha, beta
#%% loss function & sampler
def loss_funcs(comp_type, params_n):
    """Return the loss function for the given mixture component type.

    Parameters
    ----------
    comp_type : str
        'Gaussian' or 'Beta'.
    params_n : int
        Number of parameters. For 'Gaussian', 1 selects ``gaussian_loss`` and
        any other value ``multivariateGaussian_loss``; ignored for 'Beta'.

    Raises
    ------
    ValueError
        If ``comp_type`` is not supported (instead of silently returning None).
    """
    if comp_type == 'Gaussian':
        return gaussian_loss if params_n == 1 else multivariateGaussian_loss
    if comp_type == 'Beta':
        return beta_loss
    raise ValueError("comp_type must be 'Gaussian' or 'Beta', got %r" % (comp_type,))
def samplers(comp_type, params_n):
    """Return the sampler for the given mixture component type.

    Parameters
    ----------
    comp_type : str
        'Gaussian' or 'Beta'.
    params_n : int
        Number of parameters. For 'Gaussian', 1 selects ``gaussian_sampler``
        and any other value ``multivariateGaussian_sampler``; ignored for 'Beta'.

    Raises
    ------
    ValueError
        If ``comp_type`` is not supported (instead of silently returning None).
    """
    if comp_type == 'Gaussian':
        return gaussian_sampler if params_n == 1 else multivariateGaussian_sampler
    if comp_type == 'Beta':
        return beta_sampler
    raise ValueError("comp_type must be 'Gaussian' or 'Beta', got %r" % (comp_type,))
#%% Branch network
class Branch(nn.Module):
    """Placeholder for a branch network; it currently defines no layers or forward pass."""
    def __init__(self,):
        super().__init__()