Source code for pyscf.mcscf.mc_ao2mo

#!/usr/bin/env python
# Copyright 2014-2020 The PySCF Developers. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import ctypes
import time
from functools import reduce
import numpy
import h5py
from pyscf import lib
from pyscf.lib import logger
from pyscf import ao2mo
from pyscf.ao2mo import _ao2mo
from pyscf.ao2mo import outcore

# least memory requirements:
# nmo  ncore  ncas  outcore  incore
# 200  40     16    0.8GB    3.7 GB (_eri 1.6GB intermediates 1.3G)
# 250  50     16    1.7GB    8.2 GB (_eri 3.9GB intermediates 2.6G)
# 300  60     16    3.1GB    16.8GB (_eri 8.1GB intermediates 5.6G)
# 400  80     16    8.5GB    53  GB (_eri 25.6GB intermediates 19G)
# 500  100    16    19 GB
# 600  120    16    37 GB
# 750  150    16    85 GB

libmcscf = lib.load_library('libmcscf')

[docs]def trans_e1_incore(eri_ao, mo, ncore, ncas): nmo = mo.shape[1] nocc = ncore + ncas eri1 = ao2mo.incore.half_e1(eri_ao, (mo,mo[:,:nocc]), compact=False) eri1 = eri1.reshape(nmo,nocc,-1) klppshape = (0, nmo, 0, nmo) klpashape = (0, nmo, ncore, nocc) aapp = numpy.empty((ncas,ncas,nmo,nmo)) for i in range(ncas): _ao2mo.nr_e2(eri1[ncore+i,ncore:nocc], mo, klppshape, aosym='s4', mosym='s1', out=aapp[i]) ppaa = lib.transpose(aapp.reshape(ncas*ncas,-1)).reshape(nmo,nmo,ncas,ncas) aapp = None papa = numpy.empty((nmo,ncas,nmo,ncas)) for i in range(nmo): _ao2mo.nr_e2(eri1[i,ncore:nocc], mo, klpashape, aosym='s4', mosym='s1', out=papa[i]) pp = numpy.empty((nmo,nmo)) j_cp = numpy.zeros((ncore,nmo)) k_pc = numpy.zeros((nmo,ncore)) for i in range(ncore): _ao2mo.nr_e2(eri1[i,i:i+1], mo, klppshape, aosym='s4', mosym='s1', out=pp) j_cp[i] = pp.diagonal() j_pc = j_cp.T.copy() pp = numpy.empty((ncore,ncore)) for i in range(nmo): klshape = (i, i+1, 0, ncore) _ao2mo.nr_e2(eri1[i,:ncore], mo, klshape, aosym='s4', mosym='s1', out=pp) k_pc[i] = pp.diagonal() return j_pc, k_pc, ppaa, papa
# level = 1: ppaa, papa and jpc, kpc # level > 1: ppaa, papa only. It affects accuracy of hdiag
[docs]def trans_e1_outcore(mol, mo, ncore, ncas, erifile, max_memory=None, level=1, verbose=logger.WARN): time0 = (time.clock(), time.time()) log = logger.new_logger(mol, verbose) log.debug1('trans_e1_outcore level %d max_memory %d', level, max_memory) nao, nmo = mo.shape nao_pair = nao*(nao+1)//2 nocc = ncore + ncas faapp_buf = lib.H5TmpFile() if isinstance(erifile, h5py.Group): feri = erifile else: feri = lib.H5TmpFile(erifile, 'w') mo_c = numpy.asarray(mo, order='C') mo = numpy.asarray(mo, order='F') pashape = (0, nmo, ncore, nocc) papa_buf = numpy.zeros((nao,ncas,nmo*ncas)) j_pc = numpy.zeros((nmo,ncore)) k_pc = numpy.zeros((nmo,ncore)) mem_words = int(max(2000,max_memory-papa_buf.nbytes/1e6)*1e6/8) aobuflen = mem_words//(nao_pair+nocc*nmo) + 1 ao_loc = numpy.array(mol.ao_loc_nr(), dtype=numpy.int32) shranges = outcore.guess_shell_ranges(mol, True, aobuflen, None, ao_loc) intor = mol._add_suffix('int2e') ao2mopt = _ao2mo.AO2MOpt(mol, intor, 'CVHFnr_schwarz_cond', 'CVHFsetnr_direct_scf') nstep = len(shranges) maxbuflen = max([x[2] for x in shranges]) log.debug('mem_words %.8g MB, maxbuflen = %d', mem_words*8/1e6, maxbuflen) bufs1 = numpy.empty((maxbuflen, nao_pair)) bufs2 = numpy.empty((maxbuflen, nmo*ncas)) if level == 1: bufs3 = numpy.empty((maxbuflen, nao*ncore)) log.debug('mem cache %.8g MB', (bufs1.nbytes+bufs2.nbytes+bufs3.nbytes)/1e6) else: log.debug('mem cache %.8g MB', (bufs1.nbytes+bufs2.nbytes)/1e6) ti0 = log.timer('Initializing trans_e1_outcore', *time0) # fmmm, ftrans, fdrv for level 1 fmmm = libmcscf.AO2MOmmm_ket_nr_s2 ftrans = libmcscf.AO2MOtranse1_nr_s4 fdrv = libmcscf.AO2MOnr_e2_drv for istep,sh_range in enumerate(shranges): log.debug('[%d/%d], AO [%d:%d], len(buf) = %d', istep+1, nstep, *sh_range) buf = bufs1[:sh_range[2]] _ao2mo.nr_e1fill(intor, sh_range, mol._atm, mol._bas, mol._env, 's4', 1, ao2mopt, buf) if log.verbose >= logger.DEBUG1: ti1 = log.timer('AO integrals buffer', *ti0) bufpa = bufs2[:sh_range[2]] _ao2mo.nr_e1(buf, mo, pashape, 's4', 's1', out=bufpa) # jc_pp, kc_pp if level == 1: # ppaa, papa and vhf, jcp, kcp if log.verbose >= logger.DEBUG1: ti1 = log.timer('buffer-pa', *ti1) buf1 = bufs3[:sh_range[2]] fdrv(ftrans, fmmm, buf1.ctypes.data_as(ctypes.c_void_p), buf.ctypes.data_as(ctypes.c_void_p), mo.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(sh_range[2]), ctypes.c_int(nao), (ctypes.c_int*4)(0, nao, 0, ncore), ctypes.POINTER(ctypes.c_void_p)(), ctypes.c_int(0)) p0 = 0 for ij in range(sh_range[0], sh_range[1]): i,j = lib.index_tril_to_pair(ij) i0 = ao_loc[i] j0 = ao_loc[j] i1 = ao_loc[i+1] j1 = ao_loc[j+1] di = i1 - i0 dj = j1 - j0 if i == j: dij = di * (di+1) // 2 buf = numpy.empty((di,di,nao*ncore)) idx = numpy.tril_indices(di) buf[idx] = buf1[p0:p0+dij] buf[idx[1],idx[0]] = buf1[p0:p0+dij] buf = buf.reshape(di,di,nao,ncore) mo1 = mo_c[i0:i1] tmp = numpy.einsum('uvpc,pc->uvc', buf, mo[:,:ncore]) tmp = lib.dot(mo1.T, tmp.reshape(di,-1)) j_pc += numpy.einsum('vp,pvc->pc', mo1, tmp.reshape(nmo,di,ncore)) tmp = numpy.einsum('uvpc,uc->vcp', buf, mo1[:,:ncore]) tmp = lib.dot(tmp.reshape(-1,nmo), mo).reshape(di,ncore,nmo) k_pc += numpy.einsum('vp,vcp->pc', mo1, tmp) else: dij = di * dj buf = buf1[p0:p0+dij].reshape(di,dj,nao,ncore) mo1 = mo_c[i0:i1] mo2 = mo_c[j0:j1] tmp = numpy.einsum('uvpc,pc->uvc', buf, mo[:,:ncore]) tmp = lib.dot(mo1.T, tmp.reshape(di,-1)) j_pc += numpy.einsum('vp,pvc->pc', mo2, tmp.reshape(nmo,dj,ncore)) * 2 tmp = numpy.einsum('uvpc,uc->vcp', buf, mo1[:,:ncore]) tmp = lib.dot(tmp.reshape(-1,nmo), mo).reshape(dj,ncore,nmo) k_pc += numpy.einsum('vp,vcp->pc', mo2, tmp) tmp = numpy.einsum('uvpc,vc->ucp', buf, mo2[:,:ncore]) tmp = lib.dot(tmp.reshape(-1,nmo), mo).reshape(di,ncore,nmo) k_pc += numpy.einsum('up,ucp->pc', mo1, tmp) p0 += dij if log.verbose >= logger.DEBUG1: ti1 = log.timer('j_cp and k_cp', *ti1) if log.verbose >= logger.DEBUG1: ti1 = log.timer('half transformation of the buffer', *ti1) # ppaa, papa faapp_buf[str(istep)] = \ bufpa.reshape(sh_range[2],nmo,ncas)[:,ncore:nocc].reshape(-1,ncas**2).T p0 = 0 for ij in range(sh_range[0], sh_range[1]): i,j = lib.index_tril_to_pair(ij) i0 = ao_loc[i] j0 = ao_loc[j] i1 = ao_loc[i+1] j1 = ao_loc[j+1] di = i1 - i0 dj = j1 - j0 if i == j: dij = di * (di+1) // 2 buf1 = numpy.empty((di,di,nmo*ncas)) idx = numpy.tril_indices(di) buf1[idx] = bufpa[p0:p0+dij] buf1[idx[1],idx[0]] = bufpa[p0:p0+dij] else: dij = di * dj buf1 = bufpa[p0:p0+dij].reshape(di,dj,-1) mo1 = mo[j0:j1,ncore:nocc].copy() for i in range(di): lib.dot(mo1.T, buf1[i], 1, papa_buf[i0+i], 1) mo1 = mo[i0:i1,ncore:nocc].copy() buf1 = lib.dot(mo1.T, buf1.reshape(di,-1)) papa_buf[j0:j1] += buf1.reshape(ncas,dj,-1).transpose(1,0,2) p0 += dij if log.verbose >= logger.DEBUG1: ti1 = log.timer('ppaa and papa buffer', *ti1) ti0 = log.timer('gen AO/transform MO [%d/%d]'%(istep+1,nstep), *ti0) buf = buf1 = bufpa = None bufs1 = bufs2 = bufs3 = None time1 = log.timer('mc_ao2mo pass 1', *time0) log.debug1('Half transformation done. Current memory %d', lib.current_memory()[0]) nblk = int(max(8, min(nmo, (max_memory*1e6/8-papa_buf.size)/(ncas**2*nmo)))) log.debug1('nblk for papa = %d', nblk) dset = feri.create_dataset('papa', (nmo,ncas,nmo,ncas), 'f8') for i0, i1 in prange(0, nmo, nblk): tmp = lib.dot(mo[:,i0:i1].T, papa_buf.reshape(nao,-1)) dset[i0:i1] = tmp.reshape(i1-i0,ncas,nmo,ncas) papa_buf = tmp = None time1 = log.timer('papa pass 2', *time1) tmp = numpy.empty((ncas**2,nao_pair)) p0 = 0 for istep, sh_range in enumerate(shranges): tmp[:,p0:p0+sh_range[2]] = faapp_buf[str(istep)] p0 += sh_range[2] nblk = int(max(8, min(nmo, (max_memory*1e6/8-tmp.size)/(ncas**2*nmo)-1))) log.debug1('nblk for ppaa = %d', nblk) dset = feri.create_dataset('ppaa', (nmo,nmo,ncas,ncas), 'f8') for i0, i1 in prange(0, nmo, nblk): tmp1 = _ao2mo.nr_e2(tmp, mo, (i0,i1,0,nmo), 's4', 's1', ao_loc=ao_loc) tmp1 = tmp1.reshape(ncas,ncas,i1-i0,nmo) for j in range(i1-i0): dset[i0+j] = tmp1[:,:,j].transpose(2,0,1) tmp = tmp1 = None time1 = log.timer('ppaa pass 2', *time1) time0 = log.timer('mc_ao2mo', *time0) return j_pc, k_pc
# level = 1: ppaa, papa and vhf, jpc, kpc # level = 2: ppaa, papa, vhf, jpc=0, kpc=0 class _ERIS(object): def __init__(self, casscf, mo, method='incore', level=1): mol = casscf.mol nao, nmo = mo.shape ncore = casscf.ncore ncas = casscf.ncas dm_core = numpy.dot(mo[:,:ncore], mo[:,:ncore].T) vj, vk = casscf._scf.get_jk(mol, dm_core) self.vhf_c = reduce(numpy.dot, (mo.T, vj*2-vk, mo)) mem_incore, mem_outcore, mem_basic = _mem_usage(ncore, ncas, nmo) mem_now = lib.current_memory()[0] eri = casscf._scf._eri if (method == 'incore' and eri is not None and (mem_incore+mem_now < casscf.max_memory*.9) or mol.incore_anyway): if eri is None: eri = mol.intor('int2e', aosym='s8') self.j_pc, self.k_pc, self.ppaa, self.papa = \ trans_e1_incore(eri, mo, ncore, ncas) else: log = logger.Logger(casscf.stdout, casscf.verbose) self.feri = lib.H5TmpFile() max_memory = max(3000, casscf.max_memory*.9-mem_now) if max_memory < mem_basic: log.warn('Calculation needs %d MB memory, over CASSCF.max_memory (%d MB) limit', (mem_basic+mem_now)/.9, casscf.max_memory) self.j_pc, self.k_pc = \ trans_e1_outcore(mol, mo, ncore, ncas, self.feri, max_memory=max_memory, level=level, verbose=log) self.ppaa = self.feri['ppaa'] self.papa = self.feri['papa'] def _mem_usage(ncore, ncas, nmo): outcore = basic = ncas**2*nmo**2*2 * 8/1e6 incore = outcore + (ncore+ncas)*nmo**3*4/1e6 return incore, outcore, basic
[docs]def prange(start, end, step): for i in range(start, end, step): yield i, min(i+step, end)
if __name__ == '__main__': from pyscf import scf from pyscf import gto from pyscf.mcscf import mc1step mol = gto.Mole() mol.verbose = 0 mol.output = None#"out_h2o" mol.atom = [ ['O', ( 0., 0. , 0. )], ['H', ( 0., -0.757, 0.587)], ['H', ( 0., 0.757 , 0.587)],] mol.basis = {'H': 'cc-pvtz', 'O': 'cc-pvtz',} mol.build() m = scf.RHF(mol) ehf = m.scf() mc = mc1step.CASSCF(m, 6, 4) mc.verbose = 4 mo = m.mo_coeff.copy() eris0 = _ERIS(mc, mo, 'incore') eris1 = _ERIS(mc, mo, 'outcore') eris2 = _ERIS(mc, mo, 'outcore', level=1) eris3 = _ERIS(mc, mo, 'outcore', level=2) print('vhf_c', numpy.allclose(eris0.vhf_c, eris1.vhf_c)) print('j_pc ', numpy.allclose(eris0.j_pc , eris1.j_pc )) print('k_pc ', numpy.allclose(eris0.k_pc , eris1.k_pc )) print('ppaa ', numpy.allclose(eris0.ppaa , eris1.ppaa )) print('papa ', numpy.allclose(eris0.papa , eris1.papa )) print('vhf_c', numpy.allclose(eris0.vhf_c, eris2.vhf_c)) print('j_pc ', numpy.allclose(eris0.j_pc , eris2.j_pc )) print('k_pc ', numpy.allclose(eris0.k_pc , eris2.k_pc )) print('ppaa ', numpy.allclose(eris0.ppaa , eris2.ppaa )) print('papa ', numpy.allclose(eris0.papa , eris2.papa )) print('vhf_c', numpy.allclose(eris0.vhf_c, eris3.vhf_c)) print('ppaa ', numpy.allclose(eris0.ppaa , eris3.ppaa )) print('papa ', numpy.allclose(eris0.papa , eris3.papa )) ncore = mc.ncore ncas = mc.ncas nocc = ncore + ncas nmo = mo.shape[1] eri = ao2mo.incore.full(m._eri, mo, compact=False).reshape((nmo,)*4) aaap = numpy.array(eri[ncore:nocc,ncore:nocc,ncore:nocc,:]) ppaa = numpy.array(eri[:,:,ncore:nocc,ncore:nocc]) papa = numpy.array(eri[:,ncore:nocc,:,ncore:nocc]) jc_pp = numpy.einsum('iipq->ipq', eri[:ncore,:ncore,:,:]) kc_pp = numpy.einsum('ipqi->ipq', eri[:ncore,:,:,:ncore]) vhf_c = numpy.einsum('cij->ij', jc_pp)*2 - numpy.einsum('cij->ij', kc_pp) j_pc = numpy.einsum('ijj->ji', jc_pp) k_pc = numpy.einsum('ijj->ji', kc_pp) print('vhf_c', numpy.allclose(vhf_c, eris1.vhf_c)) print('j_pc ', numpy.allclose(j_pc, eris1.j_pc)) print('k_pc ', numpy.allclose(k_pc, eris1.k_pc)) print('ppaa ', numpy.allclose(ppaa , eris0.ppaa )) print('papa ', numpy.allclose(papa , eris0.papa ))