from petsc4py import PETSc
import unittest

# --------------------------------------------------------------------

class BaseTestLGMap(object):
    """Shared test cases for a local-to-global mapping (``PETSc.LGMap``).

    Concrete subclasses (which also derive from ``unittest.TestCase``)
    must set, in ``setUp``:

    * ``self.idx``   -- the list of global indices the map is built from;
    * ``self.lgmap`` -- the ``PETSc.LGMap`` built from ``self.idx``.
    """

    def _mk_idx(self, comm):
        """Return this rank's global indices, including one-index overlaps.

        Each rank owns ``lsize`` consecutive indices starting at
        ``lsize * rank``.  Every rank except the first also includes the
        index just before its range, and every rank except the last also
        includes the index just after it.
        """
        comm_size = comm.getSize()
        comm_rank = comm.getRank()
        lsize = 10
        first = lsize * comm_rank
        last = first + lsize
        if comm_rank > 0:
            first -= 1
        if comm_rank < (comm_size-1):
            last += 1
        return list(range(first, last))

    def tearDown(self):
        # Drop our reference, then ask PETSc to clean up garbage objects so
        # they do not accumulate from one test to the next.
        self.lgmap = None
        PETSc.garbage_cleanup()

    def testGetSize(self):
        size = self.lgmap.getSize()
        self.assertGreaterEqual(size, 0)
        # The map was created with one entry per index in ``self.idx``.
        self.assertEqual(size, len(self.idx))

    def testGetIndices(self):
        size = self.lgmap.getSize()
        idx = self.lgmap.getIndices()
        self.assertEqual(len(idx), size)
        for i, val in enumerate(self.idx):
            self.assertEqual(idx[i], val)

    def testGetInfo(self):
        info = self.lgmap.getInfo()
        self.assertIsInstance(info, dict)
        if self.lgmap.getComm().getSize() == 1:
            self.assertEqual(info, {})
        else:
            # Presumably this rank plus one or two overlapping neighbours
            # (see the layout built by ``_mk_idx``).
            self.assertGreater(len(info), 1)
            self.assertLess(len(info), 4)

    def testApply(self):
        idxin = list(range(self.lgmap.getSize()))
        # Local position i maps to global index self.idx[i].
        idxout = self.lgmap.apply(idxin)
        self.assertEqual(list(idxout), self.idx)
        # Second form: the caller supplies the output array to fill.
        self.lgmap.apply(idxin, idxout)
        self.assertEqual(list(idxout), self.idx)
        # All global indices are present and distinct, so the inverse map
        # recovers the original local positions.
        invmap = self.lgmap.applyInverse(idxout)
        self.assertEqual(list(invmap), idxin)

    def testApplyIS(self):
        is_in = PETSc.IS().createStride(self.lgmap.getSize())
        is_out = self.lgmap.apply(is_in)
        # Mapping an index set yields an index set with one output entry
        # per input entry.
        self.assertIsInstance(is_out, PETSc.IS)
        self.assertEqual(is_out.getLocalSize(), is_in.getLocalSize())

    def testProperties(self):
        for prop in ('size', 'indices', 'info'):
            self.assertTrue(hasattr(self.lgmap, prop), prop)

# --------------------------------------------------------------------

class TestLGMap(BaseTestLGMap, unittest.TestCase):
    """Run the shared LGMap tests on a map created from a Python list."""

    def setUp(self):
        self.idx = self._mk_idx(PETSc.COMM_WORLD)
        self.lgmap = PETSc.LGMap().create(self.idx, comm=PETSc.COMM_WORLD)

class TestLGMapIS(BaseTestLGMap, unittest.TestCase):
    """Run the shared LGMap tests on a map created from a ``PETSc.IS``."""

    def setUp(self):
        self.idx = self._mk_idx(PETSc.COMM_WORLD)
        self.iset = PETSc.IS().createGeneral(self.idx, comm=PETSc.COMM_WORLD)
        self.lgmap = PETSc.LGMap().create(self.iset)

    def tearDown(self):
        # Release the index set as well, then chain to the shared tearDown
        # so PETSc.garbage_cleanup() also runs for this test case.
        self.iset = None
        super().tearDown()

    def testSameComm(self):
        comm1 = self.lgmap.getComm()
        comm2 = self.iset.getComm()
        self.assertEqual(comm1, comm2)

# --------------------------------------------------------------------

class TestLGMapBlock(unittest.TestCase):
    """Tests for an LGMap created with a block size of ``BS``."""

    BS = 3

    def setUp(self):
        # Same index layout as BaseTestLGMap._mk_idx: ``lsize`` indices per
        # rank plus a one-index overlap with each neighbouring rank.
        comm = PETSc.COMM_WORLD
        comm_size = comm.getSize()
        comm_rank = comm.getRank()
        lsize = 10
        first = lsize * comm_rank
        last = first + lsize
        if comm_rank > 0:
            first -= 1
        if comm_rank < (comm_size-1):
            last += 1
        self.idx = list(range(first, last))
        bs = self.BS
        self.lgmap = PETSc.LGMap().create(self.idx, bs, comm=PETSc.COMM_WORLD)

    def tearDown(self):
        # Drop our reference and clean up PETSc garbage, matching the
        # other LGMap test cases in this file.
        self.lgmap = None
        PETSc.garbage_cleanup()

    def _check_info(self, info):
        """Assert that *info* has the expected shape for this map's comm."""
        self.assertIsInstance(info, dict)
        if self.lgmap.getComm().getSize() == 1:
            self.assertEqual(info, {})
        else:
            # Presumably this rank plus one or two overlapping neighbours.
            self.assertGreater(len(info), 1)
            self.assertLess(len(info), 4)

    def testGetSize(self):
        size = self.lgmap.getSize()
        self.assertGreaterEqual(size, 0)
        # Each block index expands to BS scalar indices.
        self.assertEqual(size, len(self.idx) * self.BS)

    def testGetBlockSize(self):
        bs = self.lgmap.getBlockSize()
        self.assertEqual(bs, self.BS)

    def testGetBlockIndices(self):
        size = self.lgmap.getSize()
        bs = self.lgmap.getBlockSize()
        idx = self.lgmap.getBlockIndices()
        self.assertEqual(len(idx), size//bs)
        for i, val in enumerate(self.idx):
            self.assertEqual(idx[i], val)

    def testGetIndices(self):
        size = self.lgmap.getSize()
        bs = self.lgmap.getBlockSize()
        idx = self.lgmap.getIndices()
        self.assertEqual(len(idx), size)
        # Block index ``val`` expands to scalar indices val*bs .. val*bs+bs-1.
        for i, val in enumerate(self.idx):
            for j in range(bs):
                self.assertEqual(idx[i*bs+j], val*bs+j)

    def testGetBlockInfo(self):
        self._check_info(self.lgmap.getBlockInfo())

    def testGetInfo(self):
        self._check_info(self.lgmap.getInfo())

# --------------------------------------------------------------------

if __name__ == '__main__':
    unittest.main()