#!/usr/bin/env python3
"""SmartSim regression test for the navierstokes example.

Starts a local SmartSim database, runs ``navierstokes`` with SGS training-data
output enabled for each requested libCEED backend, and compares the tensors the
solver writes to the database against reference ``.npy`` files in ``output/``.
"""
import argparse
import logging
import os
import shutil
import socket
import sys
import time
import traceback
from pathlib import Path
from typing import Tuple

import numpy as np
from smartredis import Client
from smartsim import Experiment
from smartsim.settings import RunSettings

# The bundled junit-xml package has to be on sys.path *before* junit_xml is
# imported, otherwise the insertion has no effect on that import.
# autopep8: off
sys.path.insert(0, (Path(__file__).parents[3] / "tests/junit-xml").as_posix())
from junit_xml import TestCase  # noqa: E402
# autopep8: on

logging.disable(logging.WARNING)

file_dir = Path(__file__).parent.absolute()
test_output_dir = file_dir / 'output'


def getOpenSocket():
    """Return a TCP port number that the OS reported as free.

    Binding to port 0 lets the OS choose an unused port. The socket is closed
    before returning, so the port is only free at the moment of the call.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('', 0))
        return s.getsockname()[1]


class NoError(Exception):
    """Placeholder 'exception' returned by SmartSimTest.test when a run passes."""
    pass


def assert_np_all(test, truth):
    """Raise if ``test == truth`` is not true for every element.

    Uses an explicit raise rather than ``assert`` so the check still runs under ``python -O``.
    """
    try:
        matches = bool(np.all(test == truth))
    except Exception as e:
        raise AssertionError(f"Expected {truth}, but got {test}") from e
    if not matches:
        raise AssertionError(f"Expected {truth}, but got {test}")


def assert_equal(test, truth):
    """Raise if ``test != truth`` (explicit raise, so it survives ``python -O``)."""
    try:
        matches = bool(test == truth)
    except Exception as e:
        raise AssertionError(f"Expected {truth}, but got {test}") from e
    if not matches:
        raise AssertionError(f"Expected {truth}, but got {test}")


def _poll_tensor(client, key, poll_frequency_ms=250, num_tries=5):
    """Wait for ``key`` to appear in the database; raise AssertionError if it never does.

    The poll is a side effect, so it must not live inside an ``assert`` (which ``-O`` strips).
    """
    if not client.poll_tensor(key, poll_frequency_ms, num_tries):
        raise AssertionError(f"Tensor '{key}' was not found in the database")


def verify_training_data(database_array, correct_array, ceed_resource, atol=1e-8, rtol=1e-8):
    """Verify the training data.

    Cannot just use np.allclose due to vorticity vector directionality.
    Check whether the S-frame-oriented vorticity vector's second component is just flipped.
    This can happen due to the eigenvector ordering changing based on whichever one is closest to the vorticity vector.
    If two eigenvectors are very close to the vorticity vector, this can cause the ordering to flip.
    This flipping of the vorticity vector is not incorrect, just a known sensitivity of the model.

    Args:
        database_array: Values read back from the database (indexed as 2-D, column 4 = vorticity).
        correct_array: Reference values.
        ceed_resource: libCEED resource string; used to name the dump file on failure.
        atol, rtol: Tolerances, applied as in ``np.allclose``.

    Raises:
        AssertionError: If the arrays differ anywhere other than a sign flip in column 4.
            Before raising, ``database_array`` is saved to
            ``./y0_database_values_<resource>.npy`` in the current directory.
    """
    if np.allclose(database_array, correct_array, atol=atol, rtol=rtol):
        return

    # np.isclose uses the same |a - b| <= atol + rtol * |b| test as np.allclose, but also
    # reports NaN entries as "not close" (a plain `>` comparison would treat NaN as close).
    idx_notclose = np.nonzero(~np.isclose(database_array, correct_array, atol=atol, rtol=rtol))
    if np.all(idx_notclose[1] == 4):
        # Only vorticity entries differ: accept them if they are exactly sign-flipped.
        test_fail = not np.allclose(-database_array[idx_notclose], correct_array[idx_notclose],
                                    atol=atol, rtol=rtol)
    else:
        # values other than vorticity are not close
        test_fail = True

    if test_fail:
        database_output_path = Path(
            f"./y0_database_values_{ceed_resource.replace('/', '_')}.npy").absolute()
        np.save(database_output_path, database_array)
        raise AssertionError(f"Array values in database max difference: {np.max(np.abs(correct_array - database_array))}\n"
                             f"Array saved to {database_output_path.as_posix()}")


class SmartSimTest(object):
    """Runs navierstokes against a local SmartSim database and checks what it writes."""

    def __init__(self, directory_path: Path):
        self.exp: Experiment
        self.database = None
        self.directory_path: Path = directory_path
        self.original_path: Path

    def setup(self):
        """To create the test directory and start SmartRedis database.

        Recreates ``directory_path`` from scratch, changes into it (remembering the
        previous working directory for ``teardown``), and starts a local database on
        the loopback interface at a port chosen by ``getOpenSocket``.
        """
        self.original_path = Path(os.getcwd())

        if self.directory_path.exists() and self.directory_path.is_dir():
            shutil.rmtree(self.directory_path)
        self.directory_path.mkdir()
        os.chdir(self.directory_path)

        PORT = getOpenSocket()
        self.exp = Experiment("test", launcher="local")
        self.database = self.exp.create_database(port=PORT, batch=False, interface="lo")
        self.exp.generate(self.database)
        self.exp.start(self.database)

        # SmartRedis will complain if these aren't set
        os.environ['SR_LOG_FILE'] = 'R'
        os.environ['SR_LOG_LEVEL'] = 'INFO'

    def test(self, ceed_resource) -> Tuple[bool, Exception, str]:
        """Run navierstokes on ``ceed_resource`` and validate the tensors it puts in the database.

        Returns:
            ``(passed, exception, command)``: ``exception`` is a ``NoError`` instance on
            success, otherwise the exception that caused the failure; ``command`` is the
            executable path followed by its arguments.
        """
        client = None
        arguments = []
        exe_path = "../../build/navierstokes"
        try:
            arguments = [
                '-ceed', ceed_resource,
                '-options_file', (file_dir / '../blasius.yaml').as_posix(),
                '-ts_max_steps', '2',
                '-diff_filter_grid_based_width',
                '-ts_monitor', '-snes_monitor',
                '-diff_filter_ksp_max_it', '50', '-diff_filter_ksp_monitor',
                '-degree', '1',
                '-sgs_train_enable',
                '-sgs_train_write_data_interval', '2',
                '-sgs_train_filter_width_scales', '1.2,3.1',
                '-bc_symmetry_z',
                '-dm_plex_shape', 'zbox',
                '-dm_plex_box_bd', 'none,none,periodic',
                '-dm_plex_box_faces', '4,6,1',
                '-mesh_transform',
            ]

            run_settings = RunSettings(exe_path, exe_args=arguments)

            client_exp = self.exp.create_model(f"client_{ceed_resource.replace('/', '_')}", run_settings)

            # Start the client model and wait for it to finish
            self.exp.start(client_exp, summary=False, block=True)

            client = Client(cluster=False, address=self.database.get_address()[0])

            _poll_tensor(client, "sizeInfo")
            assert_np_all(client.get_tensor("sizeInfo"), np.array([35, 12, 6, 1, 1, 0]))

            # Scalar metadata tensors: only the first entry is checked.
            for key, expected in (("check-run", 1),
                                  ("tensor-ow", 1),
                                  ("num_filter_widths", 2),
                                  ("step", 2)):
                _poll_tensor(client, key)
                assert_equal(client.get_tensor(key)[0], expected)

            # Training-data tensors, compared against reference files in test_output_dir.
            for key, reference_file in (("y.0.0", "y00_output.npy"),
                                        ("y.0.1", "y01_output.npy")):
                _poll_tensor(client, key)
                test_data_path = test_output_dir / reference_file
                if not test_data_path.is_file():
                    raise AssertionError(f"Reference data file {test_data_path.as_posix()} not found")
                correct_value = np.load(test_data_path)
                database_value = client.get_tensor(key)
                verify_training_data(database_value, correct_value, ceed_resource)

            output = (True, NoError(), exe_path + ' ' + ' '.join(arguments))
        except Exception as e:
            output = (False, e, exe_path + ' ' + ' '.join(arguments))
        finally:
            # Clear the database so the next backend's run starts from an empty store.
            if client is not None:
                client.flush_db([os.environ["SSDB"]])

        return output

    def test_junit(self, ceed_resource):
        """Run ``test`` for ``ceed_resource`` and wrap the outcome in a junit_xml TestCase.

        Failures on OCCA resources are reported as skipped rather than failed.
        """
        start: float = time.time()

        passTest, exception, args = self.test(ceed_resource)

        output = "" if isinstance(exception, NoError) else ''.join(
            traceback.TracebackException.from_exception(exception).format())

        test_case = TestCase(f'SmartSim Test {ceed_resource}',
                             elapsed_sec=time.time() - start,
                             timestamp=time.strftime(
                                 '%Y-%m-%d %H:%M:%S %Z', time.localtime(start)),
                             stdout=output,
                             stderr=output,
                             allow_multiple_subelements=True,
                             category=f'SmartSim Tests')
        test_case.args = args
        if not passTest and 'occa' in ceed_resource:
            test_case.add_skipped_info("OCCA mode not supported")
        elif not passTest:
            test_case.add_failure_info("exception", output)

        return test_case

    def teardown(self):
        """Stop the database and return to the working directory saved by ``setup``."""
        try:
            self.exp.stop(self.database)
        finally:
            # Restore the cwd even if stopping the database fails.
            os.chdir(self.original_path)


def main():
    """Parse command-line options and run the SmartSim test for each requested backend."""
    parser = argparse.ArgumentParser('Testing script for SmartSim integration')
    parser.add_argument(
        '-c',
        '--ceed-backends',
        type=str,
        nargs='*',
        default=['/cpu/self'],
        help='libCEED backend(s) to use with SmartSim tests')
    args = parser.parse_args()

    test_dir = file_dir / "smartsim_test_dir"
    print("Setting up database...", end='')
    test_framework = SmartSimTest(test_dir)
    test_framework.setup()
    print(" Done!")
    try:
        for ceed_resource in args.ceed_backends:
            print("working on " + ceed_resource + ' ...', end='')
            passTest, exception, _ = test_framework.test(ceed_resource)

            if passTest:
                print("Passed!")
            else:
                print("Failed!", file=sys.stderr)
                print('\t' + ''.join(traceback.TracebackException.from_exception(exception).format()), file=sys.stderr)
    finally:
        # Always stop the database, even if a run raised outside test()'s own handling.
        print("Cleaning up database...", end='')
        test_framework.teardown()
        print(" Done!")


if __name__ == "__main__":
    main()