import os
import shutil
import subprocess
import sys

# Driver that lets ORCA use HUMMR as an external energy/gradient program.
#
# Usage: <script> EXTINP_FILE HUMMR_INPUT NUM_THREADS NUM_PROCS
#
#   EXTINP_FILE  file name passed by the caller; it is assumed to look like
#                '<basename>_EXT...' and is used only to derive <basename>
#                (the file itself is not read).
#   HUMMR_INPUT  HUMMR input template. Everything after the first line that
#                contains 'geom' (case-insensitive) is replaced, in place, by
#                the coordinates from '<basename>_EXT.xyz' followed by 'End'.
#   NUM_THREADS  value exported as OMP_NUM_THREADS for the HUMMR run.
#   NUM_PROCS    number of MPI ranks passed to mpirun.
#
# The energy and gradient are written to '<basename>_EXT.engrad' for ORCA.


def stem_of(path):
    """Return *path* with everything from the first '.' of its filename removed.

    The "cut at the first dot" rule is applied to the filename component only,
    so paths such as './hummr.inp' or 'run.1/hummr.inp' keep their directory.
    """
    directory, name = os.path.split(path)
    return os.path.join(directory, name.split('.')[0])


def basename_from_extinp(extinp_fname):
    """Derive '<basename>' from a file name of the form '<basename>_EXT...'.

    Uses the text before the last '_EXT' so basenames or directories that
    contain underscores survive.  Falls back to the text before the first
    '_' when the name has no '_EXT' marker.
    """
    head, sep, _ = extinp_fname.rpartition('_EXT')
    if sep:
        return head
    return extinp_fname.split('_')[0]


def read_xyz(xyz_fname):
    """Read an XYZ file and return ``(n_atoms, coordinate_lines)``.

    Only the *n_atoms* lines after the two header lines are returned, each
    guaranteed to end with a newline so they can be pasted into an input file.

    Raises:
        ValueError: if the file holds fewer coordinate lines than announced.
    """
    with open(xyz_fname, 'r') as infile:
        lines = infile.readlines()
    n_atoms = int(lines[0].strip())
    coords = lines[2:2 + n_atoms]
    if len(coords) < n_atoms:
        raise ValueError('%s announces %d atoms but holds only %d coordinate lines'
                         % (xyz_fname, n_atoms, len(coords)))
    coords = [line if line.endswith('\n') else line + '\n' for line in coords]
    return n_atoms, coords


def build_hummr_input(template_lines, coords_lines):
    """Return the HUMMR input lines with the geometry block replaced.

    Keeps *template_lines* up to and including the first line containing
    'geom' (case-insensitive), then appends *coords_lines* and 'End'.
    Anything after that 'geom' line in the template is dropped.

    Raises:
        ValueError: if no 'geom' line exists.  Raising here, before the input
            file is rewritten, avoids corrupting the template.
    """
    header = []
    for line in template_lines:
        header.append(line)
        if 'geom' in line.lower():
            break
    else:
        raise ValueError("no line containing 'geom' found in HUMMR input template")
    if not header[-1].endswith('\n'):
        # Otherwise the first coordinate would be glued onto the 'geom' line.
        header[-1] += '\n'
    return header + list(coords_lines) + ['End\n']


def run_hummr(input_fname, output_fname, num_threads, num_procs):
    """Run HUMMR under mpirun, sending stdout and stderr to *output_fname*.

    Raises:
        subprocess.CalledProcessError: if mpirun exits with a non-zero status.
    """
    env = dict(os.environ, OMP_NUM_THREADS=str(num_threads))
    cmd = ['mpirun', '--bind-to', 'none', '-n', str(num_procs), 'hummr', input_fname]
    with open(output_fname, 'w') as outfile:
        subprocess.check_call(cmd, stdout=outfile, stderr=subprocess.STDOUT, env=env)


def find_iroot(input_lines):
    """Return the root index requested by the HUMMR input.

    Looks for a 'calcsagrad' line first and falls back to 'calcssgrad'
    (both case-insensitive).  The integer is the last whitespace-separated
    token of the last matching line.

    Raises:
        ValueError: if neither keyword is present.
    """
    for keyword in ('calcsagrad', 'calcssgrad'):
        matches = [line for line in input_lines if keyword in line.lower()]
        if matches:
            return int(matches[-1].split()[-1])
    raise ValueError("HUMMR input contains neither 'calcsagrad' nor 'calcssgrad'")


def parse_energy(output_lines, input_lines):
    """Return the total energy of the requested root from HUMMR output lines.

    The root count is read from the second token of the last
    'Total energy =' line (with a trailing ':' stripped).  The energy is the
    last token of the matching line within the final group of root energies.

    Raises:
        RuntimeError: if the output contains no 'Total energy =' line.
    """
    energies = [line for line in output_lines if 'Total energy =' in line]
    if not energies:
        raise RuntimeError("no 'Total energy =' line found in HUMMR output")
    n_roots = int(energies[-1].split()[1].strip(':'))
    iroot = find_iroot(input_lines)
    # NOTE(review): presumably iroot is 1-based (iroot == n_roots selects the
    # last energy line) -- confirm against HUMMR's root numbering.
    return float(energies[-1 - n_roots + iroot].split()[-1])


def parse_gradient(output_lines, n_atoms):
    """Return the flattened gradient (x, y, z per atom) as strings.

    Reads the *n_atoms* lines following the first 'Nuclear gradient' line and
    takes columns 1-3 of each.  The values are kept as text so they are
    written out with HUMMR's own precision.

    Raises:
        RuntimeError: if the gradient block is missing or truncated.
    """
    for idx, line in enumerate(output_lines):
        if 'Nuclear gradient' in line:
            block = output_lines[idx + 1:idx + 1 + n_atoms]
            break
    else:
        raise RuntimeError("no 'Nuclear gradient' block found in HUMMR output")
    if len(block) < n_atoms:
        raise RuntimeError('HUMMR gradient block has %d lines, expected %d'
                           % (len(block), n_atoms))
    gradient_flat = []
    for line in block:
        fields = line.split()
        gradient_flat.extend((fields[1], fields[2], fields[3]))
    return gradient_flat


def format_engrad(n_atoms, energy, gradient_flat):
    """Return the text of an ORCA engrad file."""
    parts = [
        '#', '# Number of atoms', '#',
        str(n_atoms),
        '#', '# The current total energy in Eh', '#',
        str(energy),
        '#', '# The current gradient in Eh/bohr', '#',
    ]
    parts.extend(gradient_flat)
    return '\n'.join(parts) + '\n'


def main(argv=None):
    """Entry point; *argv* defaults to ``sys.argv`` (see usage above)."""
    argv = sys.argv if argv is None else argv
    exctinp_file = argv[1]
    hummr_input_fname = argv[2]
    # Parse the numeric arguments up front so bad arguments fail before the
    # input file is rewritten.
    num_threads = int(argv[3])
    num_procs = int(argv[4])

    basename = basename_from_extinp(exctinp_file)
    n_atoms, coords_lines = read_xyz(basename + '_EXT.xyz')

    ### Write input and run the HUMMR calc
    stem = stem_of(hummr_input_fname)
    hummr_output_fname = stem + '.out'

    with open(hummr_input_fname, 'r') as infile:
        template_lines = infile.readlines()
    hummr_input_lines = build_hummr_input(template_lines, coords_lines)
    with open(hummr_input_fname, 'w') as outfile:
        outfile.writelines(hummr_input_lines)

    # If the previous orbital file exists, copy it to the input
    src_orbs_fname = stem + '.C0'
    if os.path.isfile(src_orbs_fname):
        shutil.copyfile(src_orbs_fname, 'inporbs.C0')

    run_hummr(hummr_input_fname, hummr_output_fname, num_threads, num_procs)

    ### Process the HUMMR output file
    # errors='replace' mirrors the old `grep --binary-files=text`: stray
    # non-UTF-8 bytes in the log must not abort parsing.
    with open(hummr_output_fname, 'r', errors='replace') as infile:
        output_lines = infile.readlines()

    energy = parse_energy(output_lines, hummr_input_lines)
    gradient_flat = parse_gradient(output_lines, n_atoms)

    ### Write the engrad-file for Orca
    with open(basename + '_EXT.engrad', 'w') as outfile:
        outfile.write(format_engrad(n_atoms, energy, gradient_flat))


if __name__ == '__main__':
    main()
