xref: /honee/tests/smartsim_regression_framework.py (revision aa34f9e7d30747aa1ffbf61951038d9c23619016)
1*aa34f9e7SJames Wright#!/usr/bin/env python3
import argparse
import logging
import os
import shutil
import socket
import sys
import time
import traceback
from pathlib import Path
from typing import Tuple

import numpy as np
from smartredis import Client
from smartsim import Experiment
from smartsim.settings import RunSettings

# autopep8 off
# Make the junit-xml package under tests/junit-xml importable *before* importing it.
sys.path.insert(0, (Path(__file__).parents[3] / "tests/junit-xml").as_posix())
from junit_xml import TestCase  # noqa: E402
# autopep8 on
21*aa34f9e7SJames Wright
# Ignore every logging call at WARNING severity or below (quiets library log chatter).
logging.disable(logging.WARNING)

# Absolute directory that contains this script.
file_dir = Path(__file__).parent.absolute()
# Directory holding the reference .npy arrays that database tensors are compared against.
test_output_dir = file_dir / 'output'
26*aa34f9e7SJames Wright
27*aa34f9e7SJames Wright
def getOpenSocket():
    """Return a TCP port number that was unused at the time of the call.

    Binds a temporary IPv4 TCP socket to port 0 so the operating system picks a
    free port, reads the assigned port back, and closes the socket.

    Note: the port is released before this returns, so another process could
    claim it before the caller binds it (an inherent race of this technique).

    Returns:
        int: The port number chosen by the operating system.
    """
    # `with` guarantees the socket is closed even if bind()/getsockname() raises.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('', 0))
        return s.getsockname()[1]
34*aa34f9e7SJames Wright
35*aa34f9e7SJames Wright
class NoError(Exception):
    """Sentinel "exception" placed in the result tuple of a passing test.

    ``SmartSimTest.test`` always returns an exception object; on success it is an
    instance of this class, and callers check ``isinstance(exc, NoError)``.
    """
38*aa34f9e7SJames Wright
39*aa34f9e7SJames Wright
def assert_np_all(test, truth):
    """Raise unless ``test == truth`` holds for every element.

    Unlike a bare ``assert``, this check is not stripped when Python runs with
    ``-O``, and the error message shows both values.

    Args:
        test: Value under test (array-like).
        truth: Expected value (array-like), compared with ``==``.

    Raises:
        Exception: If any element differs, or if evaluating the comparison itself
            raises (the original error is chained as the cause).
    """
    try:
        all_equal = bool(np.all(test == truth))
    except Exception as e:
        raise Exception(f"Expected {truth}, but got {test}") from e
    if not all_equal:
        raise Exception(f"Expected {truth}, but got {test}")
46*aa34f9e7SJames Wright
47*aa34f9e7SJames Wright
def assert_equal(test, truth):
    """Raise unless ``test == truth``.

    Unlike a bare ``assert``, this check is not stripped when Python runs with
    ``-O``, and the error message shows both values.

    Args:
        test: Value under test.
        truth: Expected value.

    Raises:
        Exception: If the values differ, or if evaluating/truth-testing the
            comparison raises (e.g. an ambiguous array result); the original
            error is chained as the cause.
    """
    try:
        equal = bool(test == truth)
    except Exception as e:
        raise Exception(f"Expected {truth}, but got {test}") from e
    if not equal:
        raise Exception(f"Expected {truth}, but got {test}")
54*aa34f9e7SJames Wright
55*aa34f9e7SJames Wright
def verify_training_data(database_array, correct_array, ceed_resource, atol=1e-8, rtol=1e-8):
    """Verify the training data

    Cannot just use np.allclose due to vorticity vector directionality.
    Check whether the S-frame-oriented vorticity vector's second component is just flipped.
    This can happen due to the eigenvector ordering changing based on whichever one is closest to the vorticity vector.
    If two eigenvectors are very close to the vorticity vector, this can cause the ordering to flip.
    This flipping of the vorticity vector is not incorrect, just a known sensitivity of the model.

    Args:
        database_array: Training data read back from the database. Indexed as a
            2-D array; column 4 is treated as the sign-ambiguous vorticity component.
        correct_array: Reference values, compared element-wise with ``database_array``.
        ceed_resource: libCEED resource string; only used to name the dump file.
        atol: Absolute tolerance, as in ``np.isclose``.
        rtol: Relative tolerance, as in ``np.isclose``.

    Raises:
        AssertionError: If any entry outside column 4 is not close, if the
            mismatching column-4 entries are not simply sign-flipped, or if the
            database holds NaN where the reference does not. Before raising,
            ``database_array`` is saved to ``./y0_database_values_<resource>.npy``.
    """
    # np.isclose applies the same test as np.allclose (|a - b| <= atol + rtol * |b|),
    # but reports NaN as "not close". A hand-rolled `abs(a - b) > tol` evaluates to
    # False for NaN and would let NaNs in the database pass. equal_nan keeps entries
    # that are NaN in both arrays matching.
    notclose = ~np.isclose(database_array, correct_array, atol=atol, rtol=rtol, equal_nan=True)
    if not np.any(notclose):
        return

    idx_notclose = np.nonzero(notclose)
    if np.all(idx_notclose[1] == 4):
        # Only vorticity entries differ: accept them if they are exactly sign-flipped.
        test_fail = not np.allclose(-database_array[idx_notclose], correct_array[idx_notclose],
                                    atol=atol, rtol=rtol, equal_nan=True)
    else:
        # values other than vorticity are not close
        test_fail = True

    if test_fail:
        database_output_path = Path(
            f"./y0_database_values_{ceed_resource.replace('/', '_')}.npy").absolute()
        np.save(database_output_path, database_array)
        raise AssertionError(f"Array values in database max difference: {np.max(np.abs(correct_array - database_array))}\n"
                             f"Array saved to {database_output_path.as_posix()}")
84*aa34f9e7SJames Wright
85*aa34f9e7SJames Wright
class SmartSimTest(object):
    """Regression test of the ``navierstokes`` SGS training-data output via SmartSim.

    Lifecycle: ``setup()`` recreates a scratch directory, ``chdir``s into it and
    starts a local SmartSim database. ``test()`` / ``test_junit()`` run the
    ``navierstokes`` executable for one libCEED backend and check the tensors it
    wrote to the database. ``teardown()`` stops the database and restores the
    original working directory.
    """

    def __init__(self, directory_path: Path):
        """Record the scratch directory; nothing is created until ``setup()``.

        Args:
            directory_path: Scratch directory for the run; ``setup()`` deletes and
                recreates it.
        """
        self.exp: Experiment  # assigned in setup()
        self.database = None
        self.directory_path: Path = directory_path
        self.original_path: Path  # assigned in setup()

    def setup(self):
        """Create a clean test directory, ``chdir`` into it, and start the SmartRedis database."""
        self.original_path = Path(os.getcwd())

        # Start from an empty directory so output from a previous run cannot leak in.
        if self.directory_path.exists() and self.directory_path.is_dir():
            shutil.rmtree(self.directory_path)
        self.directory_path.mkdir()
        os.chdir(self.directory_path)

        port = getOpenSocket()
        self.exp = Experiment("test", launcher="local")
        self.database = self.exp.create_database(port=port, batch=False, interface="lo")
        self.exp.generate(self.database)
        self.exp.start(self.database)

        # SmartRedis will complain if these aren't set
        os.environ['SR_LOG_FILE'] = 'R'
        os.environ['SR_LOG_LEVEL'] = 'INFO'

    @staticmethod
    def _wait_for_tensor(client, key):
        """Poll the database for tensor ``key`` and raise if it never appears.

        Unlike ``assert client.poll_tensor(...)``, this still polls and checks
        when Python runs with ``-O``.

        Raises:
            AssertionError: If ``poll_tensor`` reports the tensor as absent.
        """
        # Same polling arguments as before: 250 ms between polls, at most 5 tries.
        if not client.poll_tensor(key, 250, 5):
            raise AssertionError(f"Tensor '{key}' did not appear in the database")

    def _check_training_tensor(self, client, key, reference_path: Path, ceed_resource):
        """Compare training-data tensor ``key`` against the reference array in ``reference_path``.

        Raises:
            AssertionError: If the tensor is missing or does not match the reference.
            FileNotFoundError: If the reference ``.npy`` file does not exist.
        """
        self._wait_for_tensor(client, key)
        if not reference_path.is_file():
            raise FileNotFoundError(f"Reference data not found: {reference_path}")
        correct_value = np.load(reference_path)
        database_value = client.get_tensor(key)
        verify_training_data(database_value, correct_value, ceed_resource)

    def test(self, ceed_resource) -> Tuple[bool, Exception, str]:
        """Run ``navierstokes`` on ``ceed_resource`` and validate what it wrote to the database.

        Args:
            ceed_resource: libCEED resource string, e.g. ``'/cpu/self'``.

        Returns:
            ``(passed, exception, command)``: ``exception`` is a ``NoError`` on
            success, otherwise the exception that failed the test; ``command`` is
            the executable followed by its space-joined arguments.
        """
        client = None
        arguments = []
        # Relative to the scratch directory that setup() chdir'ed into.
        exe_path = "../../build/navierstokes"
        try:
            arguments = [
                '-ceed', ceed_resource,
                '-options_file', (file_dir / '../blasius.yaml').as_posix(),
                '-ts_max_steps', '2',
                '-diff_filter_grid_based_width',
                '-ts_monitor', '-snes_monitor',
                '-diff_filter_ksp_max_it', '50', '-diff_filter_ksp_monitor',
                '-degree', '1',
                '-sgs_train_enable',
                '-sgs_train_write_data_interval', '2',
                '-sgs_train_filter_width_scales', '1.2,3.1',
                '-bc_symmetry_z',
                '-dm_plex_shape', 'zbox',
                '-dm_plex_box_bd', 'none,none,periodic',
                '-dm_plex_box_faces', '4,6,1',
                '-mesh_transform',
            ]

            run_settings = RunSettings(exe_path, exe_args=arguments)
            client_exp = self.exp.create_model(f"client_{ceed_resource.replace('/', '_')}", run_settings)

            # block=True: wait for the solver to finish before inspecting the database.
            self.exp.start(client_exp, summary=False, block=True)

            client = Client(cluster=False, address=self.database.get_address()[0])

            self._wait_for_tensor(client, "sizeInfo")
            assert_np_all(client.get_tensor("sizeInfo"), np.array([35, 12, 6, 1, 1, 0]))

            # Scalar metadata tensors: (key, expected first element).
            for key, expected in (("check-run", 1), ("tensor-ow", 1),
                                  ("num_filter_widths", 2), ("step", 2)):
                self._wait_for_tensor(client, key)
                assert_equal(client.get_tensor(key)[0], expected)

            # Training-data tensors checked against stored reference arrays.
            for key, reference_file in (("y.0.0", "y00_output.npy"), ("y.0.1", "y01_output.npy")):
                self._check_training_tensor(client, key, test_output_dir / reference_file, ceed_resource)

            output = (True, NoError(), exe_path + ' ' + ' '.join(arguments))
        except Exception as e:
            output = (False, e, exe_path + ' ' + ' '.join(arguments))
        finally:
            # Clear the database (pass or fail) so the next backend starts clean.
            if client is not None:
                client.flush_db([os.environ["SSDB"]])

        return output

    def test_junit(self, ceed_resource):
        """Run ``test()`` for ``ceed_resource`` and wrap the outcome in a junit ``TestCase``.

        A failure on a backend whose resource string contains ``'occa'`` is
        recorded as skipped rather than failed.
        """
        start: float = time.time()

        passTest, exception, args = self.test(ceed_resource)

        output = "" if isinstance(exception, NoError) else ''.join(
            traceback.TracebackException.from_exception(exception).format())

        test_case = TestCase(f'SmartSim Test {ceed_resource}',
                             elapsed_sec=time.time() - start,
                             timestamp=time.strftime(
                                 '%Y-%m-%d %H:%M:%S %Z', time.localtime(start)),
                             stdout=output,
                             stderr=output,
                             allow_multiple_subelements=True,
                             category='SmartSim Tests')
        # Attach the command line that was run.
        test_case.args = args
        if not passTest and 'occa' in ceed_resource:
            test_case.add_skipped_info("OCCA mode not supported")
        elif not passTest:
            test_case.add_failure_info("exception", output)

        return test_case

    def teardown(self):
        """Stop the database and return to the directory that was current at ``setup()``.

        The working directory is restored even if stopping the database raises.
        """
        try:
            self.exp.stop(self.database)
        finally:
            os.chdir(self.original_path)
212*aa34f9e7SJames Wright
213*aa34f9e7SJames Wright
if __name__ == "__main__":
    # Pass the text as `description=`; argparse treats a first positional argument as `prog`.
    parser = argparse.ArgumentParser(description='Testing script for SmartSim integration')
    parser.add_argument(
        '-c',
        '--ceed-backends',
        type=str,
        nargs='*',
        default=['/cpu/self'],
        help='libCEED backend(s) to use with the SmartSim tests')
    args = parser.parse_args()

    test_dir = file_dir / "smartsim_test_dir"
    print("Setting up database...", end='', flush=True)
    test_framework = SmartSimTest(test_dir)
    test_framework.setup()
    print(" Done!")

    num_failed = 0
    try:
        for ceed_resource in args.ceed_backends:
            # Flush so the progress text appears before any solver/stderr output.
            print("working on " + ceed_resource + ' ...', end='', flush=True)
            passTest, exception, _ = test_framework.test(ceed_resource)

            if passTest:
                print("Passed!")
            else:
                num_failed += 1
                print("Failed!", file=sys.stderr)
                print('\t' + ''.join(traceback.TracebackException.from_exception(exception).format()), file=sys.stderr)
    finally:
        # Always stop the database, even if a test raised or the run was interrupted.
        print("Cleaning up database...", end='', flush=True)
        test_framework.teardown()
        print(" Done!")

    # Non-zero exit status lets shell/CI callers detect failing backends.
    if num_failed:
        sys.exit(1)
242*aa34f9e7SJames Wright    print(" Done!")
243