# -*- coding: utf-8 -*-
from . import sequence as seq
from . import nodeframe
import torch
import torch.nn as nn
#%% fully connected single network
class FcNet(torch.nn.Module):
    """Fully connected network built from a list of layer widths.

    Parameters
    ----------
    node_in : int
        The number of the input nodes. Only used when ``nodes`` is None.
    node_out : int
        The number of the output nodes. Only used when ``nodes`` is None.
    hidden_layer : int
        The number of the hidden layers. Only used when ``nodes`` is None.
    nodes : None or list, optional
        If list, it should be a collection of nodes of the network,
        e.g. [node_in, node_hidden1, node_hidden2, ..., node_out].
        If None, the widths are generated by ``nodeframe.decreasingNode``.
    activation_func : str, optional
        Activation function. See :func:`~.element.activation`. Default: 'RReLU'
    """

    def __init__(self, node_in=2000, node_out=6, hidden_layer=3, nodes=None,
                 activation_func='RReLU'):
        super().__init__()
        layer_nodes = nodes
        if layer_nodes is None:
            # Let nodeframe derive the full list of widths, input and output included.
            layer_nodes = nodeframe.decreasingNode(node_in=node_in,
                                                   node_out=node_out,
                                                   hidden_layer=hidden_layer,
                                                   get_allNode=True)
        # main* options configure the hidden layers and final* options the
        # last layer. The string 'None' presumably disables that component
        # inside LinearSeq; see sequence.LinearSeq.
        builder = seq.LinearSeq(layer_nodes,
                                mainActive=activation_func,
                                finalActive='None',
                                mainBN=True,
                                finalBN=False,
                                mainDropout='None',
                                finalDropout='None')
        self.fc = builder.get_seq()

    def forward(self, x):
        """Run ``x`` through the fully connected stack and return the result."""
        return self.fc(x)
#%% multibranch network
def split_nodes(nodes, weight=None):
    """Scale a list of node counts by each weight in turn.

    Parameters
    ----------
    nodes : sequence of numbers
        Node counts to be scaled.
    weight : sequence of float, optional
        One scaling factor per output group. If None or empty, no groups
        are produced.

    Returns
    -------
    list of list
        ``result[i][j] == round(nodes[j] * weight[i])``. Note that Python's
        ``round`` uses round-half-to-even, so ``round(2.5) == 2``.
    """
    # None replaces the former mutable default ``[]``. Semantics are the
    # same: no weights means no groups.
    if weight is None:
        weight = ()
    return [[round(node * w) for node in nodes] for w in weight]
class MultiBranchFcNet(nn.Module):
    """Get a multibranch network.

    Each input in ``x_all`` is passed through its own fully connected branch.
    The branch outputs are concatenated along dimension 1 and fed into a
    shared fully connected trunk.

    Parameters
    ----------
    nodes_in : list of int
        The number of the input nodes for each branch.
        e.g. [node_in_branch1, node_in_branch2, ...]
    node_out : int
        The number of the output nodes.
    branch_hiddenLayer : int
        The number of the hidden layers for the branch part.
    trunk_hiddenLayer : int
        The number of the hidden layers for the trunk part.
    nodes_all : list, optional
        The number of nodes of the multibranch network.
        e.g. [nodes_branch1, nodes_branch2, ..., nodes_trunk]
        Because branch outputs are concatenated before entering the trunk,
        the trunk's first width presumably has to equal the sum of the
        branches' last widths.
    activation_func : str, optional
        Activation function. See :func:`~.element.activation`. Default: 'RReLU'
    """

    def __init__(self, nodes_in=(100, 100, 20), node_out=6, branch_hiddenLayer=1,
                 trunk_hiddenLayer=3, nodes_all=None, activation_func='RReLU'):
        super(MultiBranchFcNet, self).__init__()
        if nodes_all is None:
            nodes_all = []
            # Each branch keeps the first branch_hiddenLayer+2 widths of a
            # decreasing node list from its own input size to node_out.
            # (branch_hiddenLayer + trunk_hiddenLayer + 1 also works here,
            # but is not necessary.)
            fc_hidden = branch_hiddenLayer*2 + 1
            for node_in in nodes_in:
                fc_node = nodeframe.decreasingNode(node_in=node_in, node_out=node_out,
                                                   hidden_layer=fc_hidden, get_allNode=True)
                nodes_all.append(fc_node[:branch_hiddenLayer+2])
            branch_outs = [branch_node[-1] for branch_node in nodes_all]
            nodes_all.append(nodeframe.decreasingNode(node_in=sum(branch_outs), node_out=node_out,
                                                      hidden_layer=trunk_hiddenLayer, get_allNode=True))
        self.branch_n = len(nodes_all) - 1
        for i in range(self.branch_n):
            # setattr goes through nn.Module.__setattr__, so each branch is
            # registered as a submodule under the same name (branch1,
            # branch2, ...) that the former exec-based code produced.
            # State_dict keys are therefore unchanged.
            setattr(self, 'branch%d' % (i + 1),
                    seq.LinearSeq(nodes_all[i], mainActive=activation_func, finalActive=activation_func,
                                  mainBN=True, finalBN=True, mainDropout='None',
                                  finalDropout='None').get_seq())
        self.trunk = seq.LinearSeq(nodes_all[-1], mainActive=activation_func, finalActive='None',
                                   mainBN=True, finalBN=False, mainDropout='None',
                                   finalDropout='None').get_seq()

    def forward(self, x_all):
        """Apply each branch to its input, concatenate, and run the trunk.

        Parameters
        ----------
        x_all : sequence of torch.Tensor
            ``x_all[i]`` is the input of branch ``i + 1``. Tensors are
            presumably (batch, features), since outputs are joined on dim 1.
        """
        # Branches are evaluated in order 1..N. This matters for stochastic
        # layers such as RReLU in training mode.
        branch_outputs = [getattr(self, 'branch%d' % (i + 1))(x_all[i])
                          for i in range(self.branch_n)]
        x_comb = torch.cat(branch_outputs, 1)
        return self.trunk(x_comb)
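# TODO(review): experimental variant of MultiBranchFcNet that sums branch outputs instead of concatenating them; candidate for removal.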
class MultiBranchFcNet_test(nn.Module):
    """Get a multibranch network whose branch outputs are summed.

    Each input in ``x_all`` is passed through its own fully connected branch.
    The branch outputs are added element-wise and fed into a shared fully
    connected trunk. The branches must therefore all produce outputs of the
    same shape.

    Parameters
    ----------
    nodes_in : list of int
        The number of the input nodes for each branch.
        e.g. [node_in_branch1, node_in_branch2, ...]
    node_out : int
        The number of the output nodes.
    branch_hiddenLayer : int
        The number of the hidden layers for the branch part.
    trunk_hiddenLayer : int
        Used together with ``branch_hiddenLayer`` to size the overall node
        list; the trunk keeps the widths after index ``branch_hiddenLayer``.
    nodes_all : list, optional
        The number of nodes of the multibranch network.
        e.g. [nodes_branch1, nodes_branch2, ..., nodes_trunk]
    activation_func : str, optional
        Activation function. See :func:`~.element.activation`. Default: 'RReLU'
    """

    def __init__(self, nodes_in=(100, 100, 20), node_out=6, branch_hiddenLayer=1,
                 trunk_hiddenLayer=3, nodes_all=None, activation_func='RReLU'):
        super(MultiBranchFcNet_test, self).__init__()
        if nodes_all is None:
            # Build one decreasing node list for the summed input size and
            # take the trunk from its tail. An alternative that was tried is
            # nodes[branch_hiddenLayer+2:].
            nodes = nodeframe.decreasingNode(node_in=sum(nodes_in), node_out=node_out,
                                             hidden_layer=branch_hiddenLayer+trunk_hiddenLayer,
                                             get_allNode=True)
            nodes_trunk = nodes[branch_hiddenLayer+1:]
            # Every branch ends at the trunk's input width so the outputs can be summed.
            node_mid = nodes_trunk[0]
            nodes_all = [nodeframe.decreasingNode(node_in=node_in, node_out=node_mid,
                                                  hidden_layer=branch_hiddenLayer, get_allNode=True)
                         for node_in in nodes_in]
            nodes_all.append(nodes_trunk)
        self.branch_n = len(nodes_all) - 1
        for i in range(self.branch_n):
            # Same submodule names (branch1, branch2, ...) as the former
            # exec-based registration, so state_dict keys are unchanged.
            setattr(self, 'branch%d' % (i + 1),
                    seq.LinearSeq(nodes_all[i], mainActive=activation_func, finalActive=activation_func,
                                  mainBN=True, finalBN=True, mainDropout='None',
                                  finalDropout='None').get_seq())
        self.trunk = seq.LinearSeq(nodes_all[-1], mainActive=activation_func, finalActive='None',
                                   mainBN=True, finalBN=False, mainDropout='None',
                                   finalDropout='None').get_seq()

    def forward(self, x_all):
        """Apply each branch to its input, sum the outputs, and run the trunk.

        Parameters
        ----------
        x_all : sequence of torch.Tensor
            ``x_all[i]`` is the input of branch ``i + 1``.
        """
        x_sum = self.branch1(x_all[0])
        for i in range(1, self.branch_n):
            x_sum = x_sum + getattr(self, 'branch%d' % (i + 1))(x_all[i])
        return self.trunk(x_sum)
#%% for annmc
class MultiBranchFcNet_MC(nn.Module):
    """Trunk-then-branches network with one output tensor per branch.

    A shared fully connected trunk maps the input to a hidden tensor. That
    tensor is split along dimension 1 into consecutive chunks of widths
    ``self.branch_ins``, and each chunk is fed into its own fully connected
    branch.

    Parameters
    ----------
    node_in : int
        The number of the input nodes of the trunk.
    nodes_out : list of int
        The number of the output nodes for each branch.
    trunk_hiddenLayer : int
        The number of the hidden layers for the trunk part.
    branch_hiddenLayer : int
        The number of the hidden layers for the branch part.
    nodes_all : list, optional
        Explicit node lists, e.g. [nodes_branch1, nodes_branch2, ..., nodes_trunk].
        The first width of each branch determines the chunk of the trunk
        output that the branch receives.
    activation_func : str, optional
        Activation function. See :func:`~.element.activation`. Default: 'RReLU'
    """

    def __init__(self, node_in=6, nodes_out=(100, 100, 20), trunk_hiddenLayer=3,
                 branch_hiddenLayer=1, nodes_all=None, activation_func='RReLU'):
        super(MultiBranchFcNet_MC, self).__init__()
        if nodes_all is None:
            nodes_all = []
            # (branch_hiddenLayer + trunk_hiddenLayer + 1 also works here,
            # but is not necessary.)
            fc_hidden = branch_hiddenLayer*2 + 1
            for n_out in nodes_out:
                fc_node = nodeframe.decreasingNode(node_in=node_in, node_out=n_out,
                                                   hidden_layer=fc_hidden, get_allNode=True)
                # Each branch keeps the last branch_hiddenLayer+2 widths.
                node_idx = len(fc_node) - (branch_hiddenLayer+2)
                nodes_all.append(fc_node[node_idx:])
            total_in = sum(branch_node[0] for branch_node in nodes_all)
            nodes_all.append(nodeframe.decreasingNode(node_in=node_in, node_out=total_in,
                                                      hidden_layer=trunk_hiddenLayer, get_allNode=True))
        self.branch_n = len(nodes_all) - 1
        # Derived from nodes_all on both paths. Previously this attribute was
        # only set when nodes_all was None, so forward() raised
        # AttributeError for a user-supplied nodes_all.
        self.branch_ins = [nodes_all[i][0] for i in range(self.branch_n)]
        # The trunk is registered before the branches, matching the original
        # order of module creation and parameter initialization.
        self.trunk = seq.LinearSeq(nodes_all[-1], mainActive=activation_func, finalActive=activation_func,
                                   mainBN=True, finalBN=True, mainDropout='None',
                                   finalDropout='None').get_seq()
        for i in range(self.branch_n):
            # Same submodule names (branch1, branch2, ...) as the former
            # exec-based registration, so state_dict keys are unchanged.
            setattr(self, 'branch%d' % (i + 1),
                    seq.LinearSeq(nodes_all[i], mainActive=activation_func, finalActive='None',
                                  mainBN=True, finalBN=False, mainDropout='None',
                                  finalDropout='None').get_seq())

    def forward(self, x):
        """Run the trunk, split its output on dim 1, and apply each branch.

        Returns
        -------
        list of torch.Tensor
            One output per branch, in branch order.
        """
        x_mid = self.trunk(x)
        x_out = []
        start = 0
        for i in range(self.branch_n):
            # Running offset replaces re-summing branch_ins[:i] on every iteration.
            stop = start + self.branch_ins[i]
            x_out.append(getattr(self, 'branch%d' % (i + 1))(x_mid[:, start:stop]))
            start = stop
        return x_out
#%%
def loss_funcs(name='L1', **kwargs):
    """Some loss functions.

    Parameters
    ----------
    name : str, optional
        Abbreviation of loss function name, which can be 'L1', 'MSE', or
        'SmoothL1'. Default: 'L1'.
    **kwargs
        Extra keyword arguments forwarded to the loss constructor
        (e.g. ``reduction='sum'``).

    Returns
    -------
    object
        The corresponding loss function.

    Raises
    ------
    ValueError
        If ``name`` is not one of the supported abbreviations. Previously
        this surfaced as an UnboundLocalError.
    """
    losses = {
        'L1': torch.nn.L1Loss,
        'MSE': torch.nn.MSELoss,
        'SmoothL1': torch.nn.SmoothL1Loss,
    }
    try:
        loss_cls = losses[name]
    except (KeyError, TypeError):
        raise ValueError("Unsupported loss function name %r; expected one of: %s"
                         % (name, ', '.join(losses))) from None
    return loss_cls(**kwargs)