# Source code for pyscf.pbc.tools.tril

# Copyright 2014-2018 The PySCF Developers. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy
import pyscf.lib

def tril_index(ki, kj):
    """Return the packed triangular index of the ordered pair ``(ki, kj)``.

    The pair is mapped to ``kj*(kj+1)//2 + ki``.  Scalars and numpy arrays
    are both accepted; with arrays the formula is applied element-wise.

    Args:
        ki: Smaller (or equal) index of each pair.
        kj: Larger (or equal) index of each pair.

    Returns:
        The packed index, with the same scalar/array nature as the inputs.

    Raises:
        AssertionError: If any ``ki`` exceeds the corresponding ``kj``.
    """
    ordered = numpy.all(ki <= kj)
    assert ordered
    # Pairs sharing the same kj are stored contiguously after the
    # kj*(kj+1)//2 pairs that belong to smaller values of kj.
    row_start = (kj * (kj + 1)) // 2
    return row_start + ki
# TODO: fairly messy and slow
def unpack_tril(in_array, nkpts, kp, kq, kr, ks):
    """Fetch the block(s) of ``in_array`` addressed by k-point indices.

    Two storage layouts are handled:

    * If ``in_array.shape[0] == nkpts`` the array is indexed directly as
      ``in_array[kp, kq, kr]`` and a copy is returned.
    * Otherwise the first axis is taken to be a packed pair index (see
      ``tril_index``) and the second axis is indexed by ``kr``.  Pairs with
      ``kp > kq`` are not looked up as given: the entry for ``(kq, kp)`` is
      read with ``ks`` in place of ``kr``, and the fetched block has its
      axes 1<->2 and 3<->4 swapped.

    Only the case where at most one of ``kp``, ``kq``, ``kr`` is a sequence
    is supported on the packed path.

    Args:
        in_array: Array holding the data in either layout.
            NOTE(review): the packed path assumes a numpy-style array whose
            trailing axes are the five axes ``shape[2:7]`` -- confirm
            against callers.
        nkpts: Number of k-points; used only to detect the layout.
        kp, kq, kr: k-point indices, each an int or a sequence of ints.
        ks: Fourth k-point index (int or sequence).  Read only for entries
            where ``kp > kq``; it must then line up element-wise with the
            (kp, kq, kr) combinations.

    Returns:
        A new array with the requested block(s).  When ``kp``, ``kq`` and
        ``kr`` are all ints the result has shape ``in_array.shape[2:7]``;
        otherwise the leading axis enumerates the index combinations.

    Raises:
        AssertionError: If more than one of ``kp``, ``kq``, ``kr`` is not
            an int on the packed path.
    """
    if in_array.shape[0] == nkpts:
        return in_array[kp, kq, kr].copy()

    def _is_int(x):
        return isinstance(x, (int, numpy.integer))

    nints = sum(_is_int(x) for x in (kp, kq, kr))
    assert nints >= 2, "at most one of kp, kq, kr may be a sequence"

    # Promote scalars to length-1 arrays so every case is handled uniformly.
    kp, kq, kr, ks = [numpy.array([x] if _is_int(x) else x)
                      for x in (kp, kq, kr, ks)]
    # One row per (kp, kq, kr) combination; numpy.array() gives us a private
    # copy that is safe to modify in place below.
    indices = numpy.array(pyscf.lib.cartesian_prod((kp, kq, kr)))

    # Rows whose pair is not stored and must be fetched via (kq, kp).
    not_tril = indices[:, 0] > indices[:, 1]
    needs_swap = not_tril.any()
    if needs_swap:
        # Swap kp <-> kq and replace kr by ks for the affected rows.
        indices[not_tril] = indices[not_tril][:, [1, 0, 2]]
        indices[not_tril, 2] = ks[not_tril]

    pair_idx = tril_index(indices[:, 0], indices[:, 1])
    # Index with two separate arrays (tuple semantics).  Passing them as a
    # single list is interpreted by NumPy >= 1.23 as one index array over
    # axis 0, which silently yields the wrong shape.
    tmp = in_array[pair_idx, indices[:, 2]].copy()

    if needs_swap:
        # Undo the index permutation on the fetched blocks.
        tmp[not_tril] = tmp[not_tril].transpose(0, 2, 1, 4, 3)

    if nints == 3:
        # Single combination requested: drop the length-1 leading axis.
        tmp = tmp.reshape(in_array.shape[2:7])
    return tmp