#!/usr/bin/env python3
"""SmartSim integration test for the navierstokes example.

Starts a local SmartSim database, runs ``navierstokes`` with SGS training-data
output enabled, and checks the tensors it stores against reference data kept in
the ``output`` directory next to this file.
"""
import argparse
import logging
import os
import shutil
import socket
import sys
import time
import traceback
from pathlib import Path
from typing import Tuple

import numpy as np
from smartredis import Client
from smartsim import Experiment
from smartsim.settings import RunSettings

# autopep8 off
# ``junit_xml`` is loaded from the tests/junit-xml directory under
# ``Path(__file__).parents[3]``. That directory must be on sys.path *before*
# the import runs; importing earlier only works if junit_xml happens to be
# installed somewhere else.
sys.path.insert(0, (Path(__file__).parents[3] / "tests/junit-xml").as_posix())
from junit_xml import TestCase  # noqa: E402
# autopep8 on

# Ignore all logging calls at WARNING severity and below.
logging.disable(logging.WARNING)

file_dir = Path(__file__).parent.absolute()
# Directory holding the reference .npy files compared against database tensors.
test_output_dir = file_dir / 'output'


def getOpenSocket():
    """Return a TCP port number that the OS reports as currently free.

    The port is found by binding to port 0 and then releasing the socket, so
    another process could still claim it before the caller uses it.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('', 0))
        return s.getsockname()[1]


class NoError(Exception):
    """Placeholder returned in place of an exception when a test passes."""
    pass


def assert_np_all(test, truth):
    """Raise an ``Exception`` reporting both values unless all of ``test == truth`` hold.

    An explicit check is used instead of ``assert`` so the comparison still runs
    when Python is started with ``-O``. Errors raised by the comparison itself
    (e.g. non-broadcastable shapes) are wrapped in the same message.
    """
    try:
        matches = bool(np.all(test == truth))
    except Exception as e:
        raise Exception(f"Expected {truth}, but got {test}") from e
    if not matches:
        raise Exception(f"Expected {truth}, but got {test}")
def assert_equal(test, truth):
    """Raise an ``Exception`` reporting both values unless ``test == truth``.

    An explicit check is used instead of ``assert`` so the comparison still runs
    when Python is started with ``-O``. Errors raised while evaluating the
    comparison are wrapped in the same message.
    """
    try:
        matches = bool(test == truth)
    except Exception as e:
        raise Exception(f"Expected {truth}, but got {test}") from e
    if not matches:
        raise Exception(f"Expected {truth}, but got {test}")


def verify_training_data(database_array, correct_array, ceed_resource, atol=1e-8, rtol=1e-8):
    """Verify the training data

    Cannot just use np.allclose due to vorticity vector directionality.
    Check whether the S-frame-oriented vorticity vector's second component is just flipped.
    This can happen due to the eigenvector ordering changing based on whichever one is closest to the vorticity vector.
    If two eigenvectors are very close to the vorticity vector, this can cause the ordering to flip.
    This flipping of the vorticity vector is not incorrect, just a known sensitivity of the model.

    Mismatching entries are located with ``np.isclose``, the same elementwise
    test ``np.allclose`` uses, so NaN entries count as mismatches. Only
    mismatches in column 4 of a (at least) 2-D array are eligible for the
    sign-flip exemption.

    Args:
        database_array: Array read back from the database.
        correct_array: Reference array.
        ceed_resource: libCEED resource string; used to name the saved file.
        atol: Absolute tolerance, as in ``np.allclose``.
        rtol: Relative tolerance, as in ``np.allclose``.

    Raises:
        AssertionError: if the arrays do not match. The database array is first
            saved to ``./y0_database_values_<ceed_resource, '/' -> '_'>.npy``.
    """
    if np.allclose(database_array, correct_array, atol=atol, rtol=rtol):
        return

    # Exactly the entries np.allclose rejected (including NaNs).
    idx_notclose = np.nonzero(~np.isclose(database_array, correct_array, atol=atol, rtol=rtol))
    # Column 4 is the vorticity entry whose sign is allowed to flip.
    only_vorticity = len(idx_notclose) > 1 and np.all(idx_notclose[1] == 4)
    if only_vorticity:
        database_vorticity = np.asarray(database_array)[idx_notclose]
        correct_vorticity = np.asarray(correct_array)[idx_notclose]
        test_fail = not np.allclose(-database_vorticity, correct_vorticity, atol=atol, rtol=rtol)
    else:
        # values other than vorticity are not close
        test_fail = True

    if test_fail:
        database_output_path = Path(
            f"./y0_database_values_{ceed_resource.replace('/', '_')}.npy").absolute()
        np.save(database_output_path, database_array)
        raise AssertionError(f"Array values in database max difference: {np.max(np.abs(correct_array - database_array))}\n"
                             f"Array saved to {database_output_path.as_posix()}")


def _poll_and_get_tensor(client, key, poll_frequency_ms=250, num_tries=5):
    """Poll the database for tensor ``key`` and return it once it exists.

    An explicit check is used rather than ``assert client.poll_tensor(...)``,
    because under ``python -O`` that assert, including the poll call, is removed.

    Raises:
        AssertionError: if ``key`` is still absent after ``num_tries`` polls.
    """
    if not client.poll_tensor(key, poll_frequency_ms, num_tries):
        raise AssertionError(f"Tensor '{key}' not found in database after {num_tries} polls")
    return client.get_tensor(key)


class SmartSimTest:
    """Run navierstokes against a local SmartSim database and verify its output."""

    def __init__(self, directory_path: Path):
        self.exp: Experiment  # assigned in setup()
        self.database = None  # database created in setup()
        self.directory_path: Path = directory_path
        self.original_path: Path  # working directory restored by teardown()

    def setup(self):
        """Create a fresh test directory, chdir into it, and start a local database."""
        self.original_path = Path(os.getcwd())

        if self.directory_path.exists() and self.directory_path.is_dir():
            shutil.rmtree(self.directory_path)
        self.directory_path.mkdir()
        os.chdir(self.directory_path)

        port = getOpenSocket()
        self.exp = Experiment("test", launcher="local")
        self.database = self.exp.create_database(port=port, batch=False, interface="lo")
        self.exp.generate(self.database)
        self.exp.start(self.database)

        # SmartRedis will complain if these aren't set
        os.environ['SR_LOG_FILE'] = 'R'
        os.environ['SR_LOG_LEVEL'] = 'INFO'

    def test(self, ceed_resource) -> Tuple[bool, Exception, str]:
        """Run navierstokes with ``ceed_resource`` and check the tensors it stores.

        Returns:
            ``(passed, exception, command)``. ``exception`` is ``NoError()`` on
            success, otherwise the error that caused the failure. ``command`` is
            the executable path followed by its arguments.
        """
        client = None
        db_address = None
        arguments = []
        # Relative to the test directory that setup() chdirs into.
        exe_path = "../../build/navierstokes"
        passed, error = False, NoError()
        try:
            arguments = [
                '-ceed', ceed_resource,
                '-options_file', (file_dir / '../examples/blasius.yaml').as_posix(),
                '-ts_max_steps', '2',
                '-diff_filter_grid_based_width',
                '-ts_monitor', '-snes_monitor',
                '-diff_filter_ksp_max_it', '50', '-diff_filter_ksp_monitor',
                '-degree', '1',
                '-sgs_train_enable',
                '-sgs_train_write_data_interval', '2',
                '-sgs_train_filter_width_scales', '1.2,3.1',
                '-bc_symmetry_z',
                '-dm_plex_shape', 'zbox',
                '-dm_plex_box_bd', 'none,none,periodic',
                '-dm_plex_box_faces', '4,6,1',
                '-mesh_transform',
            ]

            run_settings = RunSettings(exe_path, exe_args=arguments)

            client_exp = self.exp.create_model(f"client_{ceed_resource.replace('/', '_')}", run_settings)

            # Start the client model; block=True waits until it has completed.
            self.exp.start(client_exp, summary=False, block=True)

            db_address = self.database.get_address()[0]
            client = Client(cluster=False, address=db_address)

            assert_np_all(_poll_and_get_tensor(client, "sizeInfo"), np.array([35, 12, 6, 1, 1, 0]))
            assert_equal(_poll_and_get_tensor(client, "check-run")[0], 1)
            assert_equal(_poll_and_get_tensor(client, "tensor-ow")[0], 1)
            assert_equal(_poll_and_get_tensor(client, "num_filter_widths")[0], 2)
            assert_equal(_poll_and_get_tensor(client, "step")[0], 2)

            for key, reference_name in (("y.0.0", "y00_output.npy"), ("y.0.1", "y01_output.npy")):
                database_value = _poll_and_get_tensor(client, key)
                test_data_path = test_output_dir / reference_name
                if not test_data_path.is_file():
                    raise AssertionError(f"Reference data file not found: {test_data_path}")
                verify_training_data(database_value, np.load(test_data_path), ceed_resource)

            passed = True
        except Exception as e:
            error = e
        finally:
            # Flush exactly once, pass or fail, so this run's tensors do not
            # carry over into the next one.
            if client is not None:
                try:
                    client.flush_db([db_address])
                except Exception as e:
                    # Only report a cleanup failure if nothing failed earlier,
                    # so it cannot mask the original error.
                    if passed:
                        passed, error = False, e

        return passed, error, exe_path + ' ' + ' '.join(arguments)

    def test_junit(self, ceed_resource):
        """Run test() for ``ceed_resource`` and wrap the outcome in a junit_xml TestCase.

        Failures on resources containing 'occa' are recorded as skipped.
        """
        start: float = time.time()

        passed, exception, args = self.test(ceed_resource)

        output = "" if isinstance(exception, NoError) else ''.join(
            traceback.TracebackException.from_exception(exception).format())

        test_case = TestCase(f'SmartSim Test {ceed_resource}',
                             elapsed_sec=time.time() - start,
                             timestamp=time.strftime(
                                 '%Y-%m-%d %H:%M:%S %Z', time.localtime(start)),
                             stdout=output,
                             stderr=output,
                             allow_multiple_subelements=True,
                             category='SmartSim Tests')
        test_case.args = args
        if not passed and 'occa' in ceed_resource:
            test_case.add_skipped_info("OCCA mode not supported")
        elif not passed:
            test_case.add_failure_info("exception", output)

        return test_case

    def teardown(self):
        """Stop the database and return to the directory recorded by setup()."""
        try:
            self.exp.stop(self.database)
        finally:
            os.chdir(self.original_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser('Testing script for SmartSim integration')
    parser.add_argument(
        '-c',
        '--ceed-backends',
        type=str,
        nargs='*',
        default=['/cpu/self'],
        help='libCEED backend(s) to run the SmartSim test with')
    args = parser.parse_args()

    test_dir = file_dir / "smartsim_test_dir"
    print("Setting up database...", end='')
    test_framework = SmartSimTest(test_dir)
    test_framework.setup()
    print(" Done!")
    num_failed = 0
    try:
        for ceed_resource in args.ceed_backends:
            print("working on " + ceed_resource + ' ...', end='')
            passTest, exception, _ = test_framework.test(ceed_resource)

            if passTest:
                print("Passed!")
            else:
                num_failed += 1
                print("Failed!", file=sys.stderr)
                print('\t' + ''.join(traceback.TracebackException.from_exception(exception).format()), file=sys.stderr)
    finally:
        # Always stop the database, even if the loop is interrupted.
        print("Cleaning up database...", end='')
        test_framework.teardown()
        print(" Done!")

    # Non-zero exit status so callers/CI can detect failures.
    sys.exit(1 if num_failed else 0)