Source code for colfi.utils

# -*- coding: utf-8 -*-

import sys
import os
import torch
import math
import pandas
import numpy as np


#%%
class LrDecay:
    """Let the learning rate decay with iteration.

    Parameters
    ----------
    iter_mid : int
        The current iteration, i.e. the point at which the learning rate
        is evaluated.
    iteration : int
        The iteration at which the learning rate reaches ``lr_min`` when the
        decay parameters are set automatically. Default: 10000
    lr : float
        The initial learning rate. Default: 0.1
    lr_min : float
        The final (minimum) learning rate. Default: 1e-6
    """

    def __init__(self, iter_mid, iteration=10000, lr=0.1, lr_min=1e-6):
        self.lr = lr
        self.lr_min = lr_min
        self.iter_mid = iter_mid
        self.iteration = iteration

    def exp(self, gamma=0.999, auto_params=True):
        """Exponential decay.

        Parameters
        ----------
        gamma : float
            The decay factor applied at every iteration. Ignored when
            ``auto_params`` is True.
        auto_params : bool
            If True, gamma is set automatically so that the learning rate
            equals ``lr_min`` when ``iter_mid == iteration``.

        Returns
        -------
        float
            lr * gamma^iter_mid
        """
        if auto_params:
            gamma = (self.lr_min / self.lr) ** (1. / self.iteration)
        return self.lr * gamma ** self.iter_mid

    def step(self, stepsize=1000, gamma=0.3, auto_params=True):
        """Let the learning rate decay step by step, similar to 'exp'.

        Parameters
        ----------
        stepsize : int
            The number of iterations during which the learning rate is
            kept constant.
        gamma : float
            The decay factor applied after every ``stepsize`` iterations.
            Ignored when ``auto_params`` is True.
        auto_params : bool
            If True, gamma is set automatically so that the learning rate
            equals ``lr_min`` after ``iteration/stepsize`` steps.

        Returns
        -------
        float
            lr * gamma^floor(iter_mid/stepsize)
        """
        if auto_params:
            n_steps = self.iteration / stepsize
            gamma = (self.lr_min / self.lr) ** (1. / n_steps)
        return self.lr * gamma ** math.floor(self.iter_mid / stepsize)

    def poly(self, decay_step=500, power=0.999, cycle=True):
        """Polynomial decay.

        Parameters
        ----------
        decay_step : int
            The length of one decay period, only used when ``cycle`` is True.
        power : float
            The power of the polynomial.
        cycle : bool
            If True, the decay horizon is the smallest multiple of
            ``decay_step`` that is not smaller than ``iter_mid``, so the
            decay restarts after every period. If False, the horizon is
            ``iteration``.

        Returns
        -------
        float
            (lr-lr_min) * (1 - iter_mid/decay_steps)^power + lr_min
        """
        if cycle:
            # ceil(0/decay_step) is 0, which would give a zero horizon and a
            # division by zero at the very first iteration; always use at
            # least one full period.
            n_periods = max(1, math.ceil(self.iter_mid / decay_step))
            decay_steps = decay_step * n_periods
        else:
            decay_steps = self.iteration
        # Clamp the progress so that running past the horizon yields lr_min
        # instead of a negative base raised to a fractional power (which
        # evaluates to a complex number in Python 3).
        progress = min(self.iter_mid / decay_steps, 1.0)
        return (self.lr - self.lr_min) * (1 - progress) ** power + self.lr_min
#%%
def makeList(roots):
    """Wrap a single object in a list, leaving sequences untouched.

    Parameters
    ----------
    roots : object
        The object to check.

    Returns
    -------
    list or tuple
        ``roots`` itself when it already is a list or a tuple, otherwise a
        one-element list containing ``roots``.
    """
    already_sequence = isinstance(roots, (list, tuple))
    return roots if already_sequence else [roots]
#%%save files
def mkdir(path):
    """Create a directory (including missing parents) if it does not exist.

    All blank spaces are removed from ``path`` and trailing ``/`` characters
    are stripped before the directory is looked up or created.

    Parameters
    ----------
    path : str
        The path of the directory.

    Returns
    -------
    bool or None
        True if the directory was created, None if it already existed.

    Raises
    ------
    ValueError
        If ``path`` is empty once the blank spaces are removed.

    Examples
    --------
    >>> mkdir('/home/UserName/test')
    >>> mkdir('test/one')
    >>> mkdir('../test/one')
    """
    # Every blank is dropped, not only the leading/trailing ones, so a
    # separate strip() is unnecessary.
    cleaned = path.replace(' ', '')
    if cleaned == '':
        raise ValueError('The path cannot be an empty string')
    cleaned = cleaned.rstrip("/")

    if os.path.exists(cleaned):
        # Nothing to create; callers receive None in this case.
        return None

    os.makedirs(cleaned, exist_ok=True)
    print('The directory "%s" is successfully created !' % cleaned)
    return True
def savetxt(path, FileName, File):
    """Write data to ``<path>/<FileName>.txt`` with :func:`numpy.savetxt`.

    The directory is created first via :func:`mkdir` when it is missing.

    Parameters
    ----------
    path : str
        The directory in which the file is saved.
    FileName : str
        The name of the file, without the ``.txt`` suffix.
    File : object
        The data to be saved.
    """
    mkdir(path)
    destination = '/'.join((path, FileName + '.txt'))
    np.savetxt(destination, File)
def savenpy(path, FileName, File, dtype=np.float32):
    """Write data to ``<path>/<FileName>.npy`` with :func:`numpy.save`.

    The directory is created first via :func:`mkdir` when it is missing.

    Parameters
    ----------
    path : str
        The directory in which the file is saved.
    FileName : str
        The name of the file, without the ``.npy`` suffix.
    File : object
        The data to be saved.
    dtype : str or object
        The type the data is cast to before saving. The cast is only applied
        when ``File`` is exactly a ``numpy.ndarray`` and ``dtype`` is not
        ``object``. Default: ``numpy.float32``.
    """
    mkdir(path)
    # dtype=object is the escape hatch for non-numeric content such as
    # hyper-parameters, which must be stored without casting.
    is_plain_array = type(File) is np.ndarray
    if is_plain_array and dtype is not object:
        File = File.astype(dtype)
    destination = '/'.join((path, FileName + '.npy'))
    np.save(destination, File)
def saveTorchPt(path, FileName, File):
    """Write an object to ``<path>/<FileName>`` with :func:`torch.save`.

    The directory is created first via :func:`mkdir` when it is missing.
    No suffix is appended, so ``FileName`` must already contain it
    (e.g. ``'net.pt'``).

    Parameters
    ----------
    path : str
        The directory in which the file is saved.
    FileName : str
        The full name of the file, including its suffix.
    File : object
        The object to be saved.
    """
    mkdir(path)
    destination = '/'.join((path, FileName))
    torch.save(File, destination)
#%% get file path
class FilePath:
    """Obtain the path of a specific file.

    A file matches when its name ends with ``suffix`` and the last
    ``separator``-delimited field of its name (extension removed) equals
    ``randn_num``.

    Parameters
    ----------
    filedir : str
        The relative path of the directory that is searched.
    randn_num : str or float
        A random number that is owned by a file name. It is compared as a
        string.
    suffix : str
        The suffix of the file, e.g. '.npy', '.pt'
    separator : str
        Symbol for splitting the random number in the file name.
    raise_err : bool
        If True, an error is raised when no file matches; otherwise
        :meth:`filePath` returns None. Default: True
    """

    def __init__(self, filedir='ann', randn_num='', suffix='.pt', separator='_', raise_err=True):
        self.filedir = filedir
        self.randn_num = str(randn_num)
        self.separator = separator
        self.file_suffix = suffix
        self.raise_err = raise_err

    def filePath(self):
        """Search ``filedir`` for the file tagged with ``randn_num``.

        Returns
        -------
        str or None
            ``filedir + '/' + file name`` of the matching file. When several
            files match, the last one listed by :func:`os.listdir` is
            returned. None is returned when nothing matches and
            ``raise_err`` is False.

        Raises
        ------
        IOError
            If no file matches and ``raise_err`` is True.
        """
        # An explicit sentinel replaces probing the local namespace with
        # dir() to find out whether a match was ever assigned.
        target_file = None
        for File in os.listdir(self.filedir):
            if not File.endswith(self.file_suffix):
                continue
            fileName = os.path.splitext(File)[0]
            randn = fileName.split(self.separator)[-1]
            if randn == self.randn_num:
                # No break here: a later match overrides an earlier one.
                target_file = self.filedir + '/' + File

        if target_file is None:
            if self.raise_err:
                raise IOError('No eligible files with randn_num: %s ! in %s'%(self.randn_num, self.filedir))
            return None
        return target_file
#%% redirect output
class Logger(object):
    """Record the output of the terminal and write it to disk.

    Every message is forwarded to the wrapped stream and also written to
    ``<path>/<fileName>.log`` (or ``<fileName>.log`` when ``path`` is empty).

    Parameters
    ----------
    path : str
        The directory of the log file. It is created via :func:`mkdir` when
        it is not empty. Default: 'logs'
    fileName : str
        The name of the log file, without the ``.log`` suffix. Default: 'log'
    stream : object
        The stream that is wrapped, e.g. ``sys.stdout`` or ``sys.stderr``.
    """

    def __init__(self, path='logs', fileName="log", stream=sys.stdout):
        self.terminal = stream
        self.path = path
        self.fileName = fileName
        self._log()

    def _log(self):
        """Open the log file, discarding any previous content."""
        if self.path:
            mkdir(self.path)
            log_file = self.path+'/'+self.fileName+'.log'
        else:
            log_file = self.fileName+'.log'
        # Start from an empty file, then reopen in append mode. With plain
        # "w" handles, two loggers sharing one file (stdout and stderr) keep
        # independent offsets and overwrite each other's output; in append
        # mode every write lands at the current end of the file.
        open(log_file, "w").close()
        self.log = open(log_file, "a")

    def write(self, message):
        """Write ``message`` to the wrapped stream and to the log file."""
        self.terminal.write(message)
        self.log.write(message)
        # Flush immediately so that the terminal and the log stay in sync.
        self.terminal.flush()
        self.log.flush()

    def flush(self):
        """Do nothing; :meth:`write` already flushes after every message."""
        pass
def logger(path='logs', fileName='log'):
    """Redirect ``sys.stdout`` and ``sys.stderr`` through :class:`Logger`.

    Both streams keep printing to the terminal and are additionally recorded
    in ``<path>/<fileName>.log``.

    Parameters
    ----------
    path : str
        The directory of the log file. Default: 'logs'
    fileName : str
        The name of the log file, without the ``.log`` suffix. Default: 'log'
    """
    # stdout first, then stderr, each wrapping the stream it replaces.
    for stream_name in ('stdout', 'stderr'):
        wrapped = Logger(path=path, fileName=fileName, stream=getattr(sys, stream_name))
        setattr(sys, stream_name, wrapped)
#%% save predict_*.py
def get_randn_suffix(randn_num=1.234):
    """Return ``randn_num`` as a string with every '.' removed.

    Parameters
    ----------
    randn_num : str or float
        The random number that identifies a run. Default: 1.234

    Returns
    -------
    str
        The string form of ``randn_num`` without dots, e.g. '1234'.
    """
    as_text = str(randn_num)
    return as_text.replace('.', '')
def save_predict(path='ann', nde_type='ANN', randn_num=1.123, file_identity_str='',
                 chain_true_path='', label_true='True', fiducial_params=None):
    """Generate an executable ``predict*.py`` script in the current directory.

    The script is named ``predict<file_identity_str>_<nde_type>_<randn_num>.py``
    (``nde_type`` in lower case) and is overwritten if it already exists.

    Parameters
    ----------
    path : str
        Written into the script as the ``path`` argument of ``nde.Predict``.
    nde_type : str
        Used, in lower case, in the name of the script.
    randn_num : str or float
        Used in the name of the script and written as the ``randn_num``
        argument of ``nde.Predict``.
    file_identity_str : str
        An extra string inserted right after 'predict' in the script name.
    chain_true_path : str
        Written as ``predictor.chain_true_path``; only written when both
        ``chain_true_path`` and ``label_true`` are non-empty.
    label_true : str
        Written as ``predictor.label_true`` together with ``chain_true_path``.
    fiducial_params : list, tuple or array-like, optional
        Written as ``predictor.fiducial_params`` when it is not empty.
    """
    # randn_suffix = get_randn_suffix(randn_num=randn_num)
    # file_name = 'predict_%s.py'%(randn_suffix)
    file_name = 'predict%s_%s_%s.py'%(file_identity_str, nde_type.lower(), randn_num)
    file_path = os.getcwd() + '/' + file_name

    if fiducial_params is None:
        fiducial_params = []
    elif hasattr(fiducial_params, 'tolist'):
        # The string form of a NumPy array has no commas and would not be
        # valid Python in the generated script.
        fiducial_params = fiducial_params.tolist()
    elif isinstance(fiducial_params, tuple):
        fiducial_params = list(fiducial_params)

    sections = ['''\
import sys
sys.path.append('..')
sys.path.append('../..')
import colfi.nde as nde
import matplotlib.pyplot as plt

predictor = nde.Predict(path='%s', randn_num=%s)
'''%(path, randn_num)]

    if chain_true_path and label_true:
        sections.append('''\
predictor.chain_true_path = '%s'
predictor.label_true = '%s'
'''%(chain_true_path, label_true))

    if len(fiducial_params)!=0:
        # The 1-tuple keeps '%' from unpacking the sequence as several
        # format arguments.
        sections.append('''\
predictor.fiducial_params = %s
'''%(fiducial_params,))

    sections.append('''\
predictor.from_chain() #estimate parameters using the saved chains
predictor.get_steps() #plot the estimated parameters at each step
predictor.get_contour() #plot contours of the estimated parameters
predictor.get_losses() #plot the losses of the training set or/and the validation set

plt.show()
''')

    # Write once in 'w' mode: appending would duplicate the whole script
    # whenever it is regenerated under the same name.
    with open(file_path, 'w') as f:
        f.write(''.join(sections))
    # NOTE(review): 0o777 makes the script world-writable; kept as in the
    # original, consider 0o755.
    os.chmod(file_path, 0o777)
#%% #to be updated ?
def remove_nan(obs, params):
    """Remove the samples that contain 'nan', used for the simulated observations.

    A sample is an entry along the first axis of ``obs``; it is dropped from
    both ``obs`` and ``params`` when any of its values is nan.

    Parameters
    ----------
    obs : array-like
        The simulated observations, Numpy array with one or multi dimension.
    params : array-like
        The simulated parameters, Numpy array with one or multi dimension.

    Returns
    -------
    obs_new : array-like
        The new observations that do not contain nan.
    params_new : array-like
        The new parameters that do not contain nan.
    """
    nan_mask = np.isnan(obs)
    if not nan_mask.any():
        print("There are no 'nan' in the mock data.")
        # Nothing to remove: the inputs are handed back untouched.
        return obs, params

    # Collapse every trailing axis so that each sample gets a single flag.
    has_nan = nan_mask.reshape(nan_mask.shape[0], -1).any(axis=1)
    # Integer indices (in ascending order) are used for both arrays, so
    # params only needs to be at least as long as obs.
    idx_good = np.flatnonzero(~has_nan)
    obs_new = np.asarray(obs)[idx_good]
    params_new = np.asarray(params)[idx_good]
    return obs_new, params_new