#!/usr/bin/env python
# Copyright 2014-2020 The PySCF Developers. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Author: Qiming Sun <osirpt.sun@gmail.com>
#
import numpy
import pyscf.lib.logger as logger
from pyscf.mcscf import mc1step
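# [docs]  (apparently a leftover source-link marker from the rendered documentation page; not part of the module)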
def kernel(casscf, mo_coeff, tol=1e-7, conv_tol_grad=None,
           ci0=None, callback=None, verbose=None, dump_chk=True):
    '''Drive a 2-step CASSCF optimization.

    Every macro iteration builds the active-space density matrices from the
    current CI vector, performs one orbital-rotation step with that CI vector
    held fixed, transforms the integrals to the rotated orbitals and finally
    re-solves the CASCI problem in the new orbitals.

    Args:
        casscf : CASSCF-like object
            Supplies all the solver steps used here (``ao2mo``, ``casci``,
            ``rotate_orb_cc``, ``rotate_mo``, ``pack_uniq_var``,
            ``max_stepsize_scheduler``, ``canonicalize``, ``dump_chk``,
            ``fcisolver.make_rdm12``) as well as the settings ``ncore``,
            ``ncas``, ``nelecas``, ``max_cycle_macro``, ``small_rot_tol``,
            ``internal_rotation``, ``canonicalization``, ``natorb`` and
            ``sorting_mo_energy``.
        mo_coeff : 2D array
            Initial orbital coefficients.  The number of columns is taken as
            the number of orbitals.

    Kwargs:
        tol : float
            Threshold for the change of the total energy between two macro
            iterations.
        conv_tol_grad : float
            Threshold for the norm of the orbital gradient.  When None it is
            set to ``sqrt(tol)``.
        ci0 :
            Initial guess handed to the first ``casscf.casci`` call.
        callback : callable
            Called with ``locals()`` after each micro iteration and after
            each macro iteration.  When None, ``casscf.callback`` is used.
        verbose : int or logger
            Output level of the logger created here.  When None,
            ``casscf.verbose`` is used.
        dump_chk : bool
            Whether to call ``casscf.dump_chk(locals())`` after every macro
            iteration and once more before returning.

    Returns:
        A tuple ``(conv, e_tot, e_cas, fcivec, mo, mo_energy)``: convergence
        flag, the two energies and the CI vector returned by the last
        ``casscf.casci`` call, the final orbital coefficients, and the
        orbital energies from ``casscf.canonicalize`` (None when
        ``casscf.canonicalization`` is off).

    Note:
        ``locals()`` is handed to ``casci``, ``max_stepsize_scheduler``,
        ``callback`` and ``dump_chk``, so the names of the local variables in
        this function are visible to those hooks.  Do not rename or drop them
        without checking the consumers.
    '''
    from pyscf.mcscf.addons import StateAverageMCSCFSolver
    if verbose is None:
        verbose = casscf.verbose
    if callback is None:
        callback = casscf.callback

    log = logger.Logger(casscf.stdout, verbose)
    cput0 = (logger.process_clock(), logger.perf_counter())
    log.debug('Start 2-step CASSCF')

    mo = mo_coeff
    nmo = mo.shape[1]
    ncore = casscf.ncore
    ncas = casscf.ncas
    nocc = ncore + ncas
    eris = casscf.ao2mo(mo)
    e_tot, e_cas, fcivec = casscf.casci(mo, ci0, eris, log, locals())

    # The active space spans every orbital and active-active rotations are
    # not requested: the initial CASCI result is returned as converged.
    if ncas == nmo and not casscf.internal_rotation:
        if casscf.canonicalization:
            log.debug('CASSCF canonicalization')
            mo, fcivec, mo_energy = casscf.canonicalize(mo, fcivec, eris,
                                                        casscf.sorting_mo_energy,
                                                        casscf.natorb, verbose=log)
        else:
            mo_energy = None
        return True, e_tot, e_cas, fcivec, mo, mo_energy

    if conv_tol_grad is None:
        conv_tol_grad = numpy.sqrt(tol)
        logger.info(casscf, 'Set conv_tol_grad to %g', conv_tol_grad)
    conv_tol_ddm = conv_tol_grad * 3
    conv = False
    de, elast = e_tot, e_tot
    totmicro = totinner = 0
    # Scalar placeholder so that the first |dm1 - dm1_old| is simply |dm1|.
    casdm1 = 0
    r0 = None

    t2m = t1m = log.timer('Initializing 2-step CASSCF', *cput0)
    imacro = 0
    while not conv and imacro < casscf.max_cycle_macro:
        imacro += 1
        njk = 0
        t3m = t2m
        casdm1_old = casdm1
        casdm1, casdm2 = casscf.fcisolver.make_rdm12(fcivec, ncas, casscf.nelecas)
        norm_ddm = numpy.linalg.norm(casdm1 - casdm1_old)
        t3m = log.timer('update CAS DM', *t3m)
        max_cycle_micro = 1 # casscf.micro_cycle_scheduler(locals())
        max_stepsize = casscf.max_stepsize_scheduler(locals())
        for imicro in range(max_cycle_micro):
            rota = casscf.rotate_orb_cc(mo, lambda:fcivec, lambda:casdm1, lambda:casdm2,
                                        eris, r0, conv_tol_grad*.3, max_stepsize, log)
            # Only the first step yielded by the generator is consumed.
            u, g_orb, njk1, r0 = next(rota)
            rota.close()
            njk += njk1
            norm_t = numpy.linalg.norm(u-numpy.eye(nmo))
            norm_gorb = numpy.linalg.norm(g_orb)
            if imicro == 0:
                # Not read in this function; visible to hooks via locals().
                norm_gorb0 = norm_gorb
            de = numpy.dot(casscf.pack_uniq_var(u), g_orb)
            t3m = log.timer('orbital rotation', *t3m)

            # Drop the reference to the old integrals before building the new
            # ones (presumably to lower peak memory -- TODO confirm).
            eris = None
            u = u.copy()
            g_orb = g_orb.copy()
            mo = casscf.rotate_mo(mo, u, log)
            eris = casscf.ao2mo(mo)
            t3m = log.timer('update eri', *t3m)

            log.debug('micro %2d ~dE=%5.3g |u-1|=%5.3g |g[o]|=%5.3g |dm1|=%5.3g',
                      imicro, de, norm_t, norm_gorb, norm_ddm)

            if callable(callback):
                callback(locals())

            t2m = log.timer('micro iter %2d'%imicro, *t2m)
            if norm_t < 1e-4 or abs(de) < tol*.4 or norm_gorb < conv_tol_grad*.2:
                break

        totinner += njk
        totmicro += imicro + 1

        max_offdiag_u = numpy.abs(numpy.triu(u, 1)).max()
        if max_offdiag_u < casscf.small_rot_tol:
            small_rot = True
            log.debug('Small orbital rotation, restart CI if supported by solver')
        else:
            small_rot = False
        # CI vectors that are not numpy arrays are replaced by the small_rot
        # flag before the next CASCI call (presumably a restart hint for
        # external solvers -- TODO confirm against the solver interface).
        if not isinstance(casscf, StateAverageMCSCFSolver):
            # The fcivec from builtin FCI solver is a numpy.ndarray
            if not isinstance(fcivec, numpy.ndarray):
                fcivec = small_rot
        else:
            newvecs = []
            for subvec in fcivec:
                if not isinstance(subvec, numpy.ndarray):
                    newvecs.append(small_rot)
                else:
                    newvecs.append(subvec)
            fcivec = newvecs

        e_tot, e_cas, fcivec = casscf.casci(mo, fcivec, eris, log, locals())
        log.timer('CASCI solver', *t3m)
        t2m = t1m = log.timer('macro iter %2d'%imacro, *t1m)

        # elast is refreshed here in either case, so no separate update is
        # needed when the convergence test below fails.
        de, elast = e_tot - elast, e_tot
        if (abs(de) < tol and
            norm_gorb < conv_tol_grad and norm_ddm < conv_tol_ddm and
            (max_offdiag_u < casscf.small_rot_tol or casscf.small_rot_tol == 0)):
            conv = True

        if dump_chk:
            casscf.dump_chk(locals())

        if callable(callback):
            callback(locals())

    if conv:
        log.info('2-step CASSCF converged in %3d macro (%3d JK %3d micro) steps',
                 imacro, totinner, totmicro)
    else:
        log.info('2-step CASSCF not converged, %3d macro (%3d JK %3d micro) steps',
                 imacro, totinner, totmicro)

    if casscf.canonicalization:
        log.info('CASSCF canonicalization')
        mo, fcivec, mo_energy = \
                casscf.canonicalize(mo, fcivec, eris, casscf.sorting_mo_energy,
                                    casscf.natorb, casdm1, log)
        if casscf.natorb and dump_chk: # dump_chk may save casdm1
            occ, ucas = casscf._eig(-casdm1, ncore, nocc)
            casdm1 = numpy.diag(-occ)
    else:
        if casscf.natorb:
            # FIXME (pyscf-2.0): Whether to transform natural orbitals in
            # active space when this flag is enabled?
            log.warn('The attribute natorb of mcscf object affects only the '
                     'orbital canonicalization.\n'
                     'If you would like to get natural orbitals in active space '
                     'without touching core and external orbitals, an explicit '
                     'call to mc.cas_natorb_() is required')
        mo_energy = None

    if dump_chk:
        casscf.dump_chk(locals())

    log.timer('2-step CASSCF', *cput0)
    return conv, e_tot, e_cas, fcivec, mo, mo_energy
if __name__ == '__main__':
    # Demo / manual check: run the 2-step solver on two small systems and
    # print the difference from the reference energies hard-coded below.
    from pyscf import gto
    from pyscf import scf

    mol = gto.Mole()
    mol.verbose = 0
    mol.output = None  # "out_h2o"

    # System 1: eight hydrogen atoms, CAS(4,4), calling kernel() directly.
    mol.atom = [
        ['H', (1., -1., 0.)],
        ['H', (0., -1., -1.)],
        ['H', (1., -0.5, -1.)],
        ['H', (0., -0.5, -1.)],
        ['H', (0., -0.5, -0.)],
        ['H', (0., -0., -1.)],
        ['H', (1., -0.5, 0.)],
        ['H', (0., 1., 1.)],
    ]

    mol.basis = {'H': 'sto-3g',
                 'O': '6-31g',}
    mol.build()

    m = scf.RHF(mol)
    ehf = m.scf()
    emc = kernel(mc1step.CASSCF(m, 4, 4), m.mo_coeff, verbose=4)[1]
    print(ehf, emc, emc-ehf)
    print(emc - -3.22013929407)

    # System 2: water in cc-pvdz, CAS(6,4), going through mc.mc2step().
    mol.atom = [
        ['O', (0., 0., 0.)],
        ['H', (0., -0.757, 0.587)],
        ['H', (0., 0.757, 0.587)],
    ]
    mol.basis = {'H': 'cc-pvdz',
                 'O': 'cc-pvdz',}
    mol.build()

    m = scf.RHF(mol)
    ehf = m.scf()
    mc = mc1step.CASSCF(m, 6, 4)
    mc.verbose = 4
    # Reorder columns 2..4 of the RHF orbitals for the initial guess.
    mo = m.mo_coeff.copy()
    mo[:,2:5] = m.mo_coeff[:,[4,2,3]]
    emc = mc.mc2step(mo)[0]
    print(ehf, emc, emc-ehf)
    #-76.0267656731 -76.0873922924 -0.0606266193028
    print(emc - -76.0873923174, emc - -76.0926176464)