#!/usr/bin/env python3
import argparse
import logging
import os
import shutil
import socket
import sys
import time
import traceback
from pathlib import Path
from typing import Tuple

import numpy as np
from smartredis import Client
from smartsim import Experiment
from smartsim.settings import RunSettings

# junit_xml may be provided by the copy in tests/junit-xml (located relative to this file via
# parents[3]), so that directory must be on sys.path *before* junit_xml is imported.
# Keep the formatter from hoisting the import above the path insertion.
# autopep8 off
sys.path.insert(0, (Path(__file__).parents[3] / "tests/junit-xml").as_posix())
from junit_xml import TestCase  # noqa: E402
# autopep8 on

logging.disable(logging.WARNING)

file_dir = Path(__file__).parent.absolute()
# Directory holding the reference .npy files the database contents are compared against
test_output_dir = file_dir / 'output'


def getOpenSocket():
    """Return a TCP port number that was free at the time of the call.

    The OS chooses the port when binding to port 0. The socket is closed before returning,
    so another process could still claim the port before the caller uses it.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('', 0))
        return s.getsockname()[1]


class NoError(Exception):
    """Placeholder exception object reported in place of a real error when a test passes."""


def assert_np_all(test, truth):
    """Raise an ``Exception`` reporting both values unless ``np.all(test == truth)`` holds.

    Errors raised by the comparison itself (e.g. incompatible shapes) are wrapped in the
    same message. An explicit check is used instead of ``assert`` so it also runs under
    ``python -O``.
    """
    try:
        matches = bool(np.all(test == truth))
    except Exception as e:
        raise Exception(f"Expected {truth}, but got {test}") from e
    if not matches:
        raise Exception(f"Expected {truth}, but got {test}")
def _require(condition, message):
    """Raise ``AssertionError(message)`` unless ``condition`` is truthy.

    Used instead of the ``assert`` statement so the checks in this script still run when
    Python is started with ``-O`` (which strips ``assert`` statements).
    """
    if not condition:
        raise AssertionError(message)


def _format_exception(exception):
    """Return the full formatted traceback of ``exception`` as a single string."""
    return ''.join(traceback.TracebackException.from_exception(exception).format())


def assert_equal(test, truth):
    """Raise an ``Exception`` reporting both values unless ``test == truth``.

    Errors raised while comparing (or while taking the truth value of the comparison) are
    wrapped in the same message. No ``assert`` statement is used, so this works under ``-O``.
    """
    try:
        equal = bool(test == truth)
    except Exception as e:
        raise Exception(f"Expected {truth}, but got {test}") from e
    if not equal:
        raise Exception(f"Expected {truth}, but got {test}")


def verify_training_data(database_array, correct_array, ceed_resource, atol=1e-8, rtol=1e-8):
    """Verify the training data

    Cannot just use np.allclose due to vorticity vector directionality.
    Check whether the S-frame-oriented vorticity vector's second component is just flipped.
    This can happen due to the eigenvector ordering changing based on whichever one is closest to the vorticity vector.
    If two eigenvectors are very close to the vorticity vector, this can cause the ordering to flip.
    This flipping of the vorticity vector is not incorrect, just a known sensitivity of the model.

    Entries are compared with the same element-wise test ``np.allclose`` uses. If every
    mismatching entry lies in column 4, the arrays are accepted when negating those database
    entries makes them close. The arrays need at least two dimensions (the column index of
    each mismatch is inspected).

    Args:
        database_array: values read back from the database.
        correct_array: reference values.
        ceed_resource: libCEED resource string; used to name the dump file on failure.
        atol, rtol: absolute / relative tolerances, as in ``np.isclose``.

    Raises:
        AssertionError: if the arrays do not match. ``database_array`` is first saved to
            ``./y0_database_values_<resource>.npy`` in the current working directory.
    """
    # np.isclose applies the np.allclose formula element-wise, and (equal_nan=False) never
    # treats NaN as close, so a NaN in the database is reported instead of slipping through.
    notclose = ~np.isclose(database_array, correct_array, atol=atol, rtol=rtol)
    if not np.any(notclose):
        return

    idx_notclose = np.nonzero(notclose)
    if np.all(idx_notclose[1] == 4):
        # Only column 4 differs: accept it if the database values are just sign-flipped
        database_vorticity = database_array[idx_notclose]
        correct_vorticity = correct_array[idx_notclose]
        test_fail = not np.allclose(-database_vorticity, correct_vorticity, atol=atol, rtol=rtol)
    else:
        # values other than vorticity are not close
        test_fail = True

    if test_fail:
        database_output_path = Path(
            f"./y0_database_values_{ceed_resource.replace('/', '_')}.npy").absolute()
        np.save(database_output_path, database_array)
        raise AssertionError(f"Array values in database max difference: {np.max(np.abs(correct_array - database_array))}\n"
                             f"Array saved to {database_output_path.as_posix()}")


class SmartSimTest(object):
    """Regression test of the data navierstokes writes to a SmartSim database.

    Call ``setup()`` once, then ``test()`` or ``test_junit()`` per libCEED backend, then
    ``teardown()``.
    """

    # (database key, reference file in test_output_dir) for the training-data tensors
    _REFERENCE_TENSORS = (
        ("y.0.flow_solution", "y0flow_solution_output.npy"),
        ("y.0.0", "y00_output.npy"),
        ("y.0.1", "y01_output.npy"),
    )

    def __init__(self, directory_path: Path):
        self.exp: Experiment  # created in setup()
        self.database = None  # database object from Experiment.create_database(), set in setup()
        self.directory_path: Path = directory_path  # scratch directory, wiped and recreated by setup()
        self.original_path: Path  # working directory at setup(), restored by teardown()

    def setup(self):
        """Create a fresh test directory, chdir into it and start the SmartRedis database.

        An existing directory at ``directory_path`` is deleted first. The database is created
        on the loopback interface with a port chosen by ``getOpenSocket()``.
        """
        self.original_path = Path(os.getcwd())

        if self.directory_path.is_dir():
            shutil.rmtree(self.directory_path)
        self.directory_path.mkdir()
        os.chdir(self.directory_path)

        port = getOpenSocket()
        self.exp = Experiment("test", launcher="local")
        self.database = self.exp.create_database(port=port, batch=False, interface="lo")
        self.exp.generate(self.database)
        self.exp.start(self.database)

        # SmartRedis will complain if these aren't set
        os.environ['SR_LOG_FILE'] = 'R'
        os.environ['SR_LOG_LEVEL'] = 'INFO'

    @staticmethod
    def _wait_for_tensor(client, key, num_tries=5):
        """Poll via ``client.poll_tensor(key, 250, num_tries)``; raise if ``key`` never appears."""
        _require(client.poll_tensor(key, 250, num_tries),
                 f"Timed out waiting for tensor '{key}' in the database")

    def test(self, ceed_resource) -> Tuple[bool, Exception, str]:
        """Run navierstokes with ``ceed_resource`` and check the tensors it put in the database.

        Returns:
            ``(passed, exception, command)``. ``exception`` is a ``NoError`` instance when the
            test passed, otherwise the exception that failed it. ``command`` is the executable
            path and its arguments joined by spaces.
        """
        client = None
        arguments = []
        exe_path = "../../build/navierstokes"
        passed, error = True, NoError()
        try:
            arguments = [
                '-ceed', ceed_resource,
                '-options_file', (file_dir / '../examples/blasius.yaml').as_posix(),
                '-ts_max_steps', '2',
                '-diff_filter_grid_based_width',
                '-ts_monitor_wall_clock_time', '-snes_monitor', '-ts_view_pre',
                '-diff_filter_ksp_max_it', '50', '-diff_filter_ksp_monitor',
                '-degree', '1',
                '-sgs_train_enable',
                '-sgs_train_write_data_interval', '2',
                '-sgs_train_filter_width_scales', '1.2,3.1',
                '-bc_symmetry_z',
                '-dm_plex_shape', 'zbox',
                '-dm_plex_box_bd', 'none,none,periodic',
                '-dm_plex_box_faces', '4,6,1',
                '-mesh_transform', '-ts_monitor_smartsim_solution',
            ]

            run_settings = RunSettings(exe_path, exe_args=arguments)
            client_exp = self.exp.create_model(f"client_{ceed_resource.replace('/', '_')}", run_settings)

            # Start the client model and wait for it to finish
            self.exp.start(client_exp, summary=False, block=True)

            client = Client(cluster=False, address=self.database.get_address()[0])

            self._wait_for_tensor(client, "sizeInfo")
            assert_np_all(client.get_tensor("sizeInfo"), np.array([35, 12, 6, 1, 1, 0]))

            # Scalar tensors: (key, expected first entry, poll attempts)
            for key, expected, num_tries in (("check-run", 1, 5),
                                             ("tensor-ow", 1, 5),
                                             ("num_filter_widths", 2, 5),
                                             ("step", 2, 10)):
                self._wait_for_tensor(client, key, num_tries)
                assert_equal(client.get_tensor(key)[0], expected)

            for key, reference_file in self._REFERENCE_TENSORS:
                self._wait_for_tensor(client, key)
                test_data_path = test_output_dir / reference_file
                _require(test_data_path.is_file(), f"Reference data file {test_data_path} not found")
                correct_value = np.load(test_data_path)
                database_value = client.get_tensor(key)
                verify_training_data(database_value, correct_value, ceed_resource)
        except Exception as e:
            passed, error = False, e
        finally:
            if client is not None:
                try:
                    client.flush_db([os.environ["SSDB"]])
                except Exception as e:
                    # Don't let a cleanup failure escape test(): its callers do not catch it.
                    # Report it only if nothing else failed first.
                    if passed:
                        passed, error = False, e

        return passed, error, exe_path + ' ' + ' '.join(arguments)

    def test_junit(self, ceed_resource):
        """Run ``test(ceed_resource)`` and wrap the outcome in a junit_xml ``TestCase``.

        A failure on a resource containing 'occa' is recorded as skipped, not failed.
        """
        start: float = time.time()

        passTest, exception, args = self.test(ceed_resource)

        output = "" if isinstance(exception, NoError) else _format_exception(exception)

        test_case = TestCase(f'SmartSim Test {ceed_resource}',
                             elapsed_sec=time.time() - start,
                             timestamp=time.strftime(
                                 '%Y-%m-%d %H:%M:%S %Z', time.localtime(start)),
                             stdout=output,
                             stderr=output,
                             allow_multiple_subelements=True,
                             category='SmartSim Tests')
        test_case.args = args
        if not passTest:
            if 'occa' in ceed_resource:
                test_case.add_skipped_info("OCCA mode not supported")
            else:
                test_case.add_failure_info("exception", output)

        return test_case

    def teardown(self):
        """Stop the database and return to the working directory saved by ``setup()``."""
        try:
            self.exp.stop(self.database)
        finally:
            os.chdir(self.original_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser('Testing script for SmartSim integration')
    parser.add_argument(
        '-c',
        '--ceed-backends',
        type=str,
        nargs='*',
        default=['/cpu/self'],
        help='libCEED backend(s) to run the SmartSim regression test with')
    args = parser.parse_args()

    test_dir = file_dir / "smartsim_test_dir"
    print("Setting up database...", end='', flush=True)
    test_framework = SmartSimTest(test_dir)
    test_framework.setup()
    print(" Done!")

    num_failed = 0
    try:
        for ceed_resource in args.ceed_backends:
            print("working on " + ceed_resource + ' ...', end='', flush=True)
            passTest, exception, _ = test_framework.test(ceed_resource)

            if passTest:
                print("Passed!")
            else:
                num_failed += 1
                print("Failed!", file=sys.stderr)
                print('\t' + _format_exception(exception), file=sys.stderr)
    finally:
        # Always stop the database, even if a test run raised unexpectedly
        print("Cleaning up database...", end='', flush=True)
        test_framework.teardown()
        print(" Done!")

    # Non-zero exit status so callers can detect failing backends
    sys.exit(1 if num_failed else 0)