xref: /libCEED/rust/libceed-sys/c-src/backends/magma/tuning/generate_tuning.py (revision 78d85032af795416bc0e2fde983867ebd52e264b)
#!/usr/bin/env python3

# Copyright (c) 2017-2018, Lawrence Livermore National Security, LLC.
# Produced at the Lawrence Livermore National Laboratory. LLNL-CODE-734707.
# All Rights reserved. See files LICENSE and NOTICE for details.
#
# This file is part of CEED, a collection of benchmarks, miniapps, software
# libraries and APIs for efficient high-order finite element and spectral
# element discretizations for exascale applications. For more information and
# source code availability see http://github.com/ceed
#
# The CEED research is supported by the Exascale Computing Project 17-SC-20-SC,
# a collaborative effort of two U.S. Department of Energy organizations (Office
# of Science and the National Nuclear Security Administration) responsible for
# the planning and preparation of a capable exascale ecosystem, including
# software, applications, hardware, advanced system engineering and early
# testbed platforms, in support of the nation's exascale computing imperative.
1826bdecf3SSebastian Grimberg
1926bdecf3SSebastian Grimbergimport argparse
2026bdecf3SSebastian Grimbergimport os
21acc0bb12SSebastian Grimbergimport glob
2226bdecf3SSebastian Grimbergimport re
23acc0bb12SSebastian Grimbergimport shutil
2426bdecf3SSebastian Grimbergimport subprocess
2526bdecf3SSebastian Grimbergimport pandas as pd
2626bdecf3SSebastian Grimbergimport time
2726bdecf3SSebastian Grimberg
# Directory containing this script (symlinks resolved); the paths used by the
# tuning workflow are built relative to it.
_script_path = os.path.realpath(__file__)
script_dir = os.path.dirname(_script_path)
2926bdecf3SSebastian Grimberg
3026bdecf3SSebastian Grimberg
def benchmark(nb, build_cmd, backend, log):
    """Rebuild with block size ``nb``, run the tuning executable, and load its output.

    The function temporarily rewrites two source files next to this script's
    parent directory (``ceed-magma.h`` and ``ceed-magma-gemm-selector.cpp``),
    runs the library build followed by ``make tuning OPT=-O0``, and then
    restores the original sources from ``.backup`` copies. The restore happens
    even if patching or building raises.

    Args:
        nb: Value substituted after ``#define CEED_AUTOTUNE_RTC_NB``.
        build_cmd: Build command run from three directories above this script.
            A string is split into arguments with ``shlex.split`` so commands
            such as ``"make -j8"`` work; a sequence of arguments is used as is.
        backend: Ceed resource string passed to the tuning executable. If it
            contains ``"hip"`` the queue-sync macro is set to
            ``hipDeviceSynchronize()``, otherwise ``cudaDeviceSynchronize()``.
        log: Path of the file that receives the executable's stdout and stderr
            and is then parsed as whitespace-separated columns.

    Returns:
        A ``pandas.DataFrame`` with columns ``P``, ``Q``, ``N``, ``Q_COMP``,
        ``TRANS`` and ``MFLOPS`` read from ``log``.

    Raises:
        subprocess.CalledProcessError: If either build step exits non-zero.
        OSError: If a source file cannot be copied, read, or written.
    """
    # Local import keeps this function self-contained with respect to the
    # file's existing import block.
    import shlex

    ceed_magma_h = f"{script_dir}/../ceed-magma.h"
    ceed_magma_gemm_selector_cpp = f"{script_dir}/../ceed-magma-gemm-selector.cpp"
    sync_call = (
        "hipDeviceSynchronize()" if "hip" in backend else "cudaDeviceSynchronize()")

    # (file, line pattern, replacement) for each temporary source edit.
    patches = [
        (ceed_magma_h,
         r".*(#define ceed_magma_queue_sync\(\.\.\.\)).*",
         r"\1 " + sync_call),
        (ceed_magma_gemm_selector_cpp,
         r".*(#define CEED_AUTOTUNE_RTC_NB).*",
         r"\1 " + f"{nb}"),
    ]

    # Without a shell, a plain string is treated as a single program name on
    # POSIX, so a command with arguments must be split first.
    if isinstance(build_cmd, str):
        build_args = shlex.split(build_cmd)
    else:
        build_args = list(build_cmd)

    # Build for new NB
    backed_up = []
    try:
        for path, pattern, replacement in patches:
            shutil.copyfile(path, path + ".backup")
            backed_up.append(path)
            with open(path, "r") as f:
                text = f.read()
            with open(path, "w") as f:
                f.write(re.sub(pattern, replacement, text))

        # check=True: a failed build would otherwise be benchmarked using a
        # stale or missing tuning executable.
        subprocess.run(build_args, cwd=f"{script_dir}/../../..", check=True)
        subprocess.run(["make", "tuning", "OPT=-O0"],
                       cwd=f"{script_dir}", check=True)
    finally:
        # Always put the original sources back, even on failure.
        for path in backed_up:
            shutil.move(path + ".backup", path)

    # Run the benchmark
    with open(log, "w") as f:
        subprocess.run(
            [f"{script_dir}/tuning", f"{backend}"], stdout=f, stderr=f)

    # sep=r"\s+" is the documented equivalent of the deprecated
    # delim_whitespace=True.
    return pd.read_csv(
        log,
        header=None,
        sep=r"\s+",
        names=["P", "Q", "N", "Q_COMP", "TRANS", "MFLOPS"])
8126bdecf3SSebastian Grimberg
8226bdecf3SSebastian Grimberg
# Columns emitted for every record of the generated tables.
RTC_COLUMNS = ["P", "Q", "N", "Q_COMP", "NB"]

# Separator line written above each section of the generated header.
RTC_RULE = "////////////////////////////////////////////////////////////////////////////////\n"


def next_nb(nb):
    """Return the block size to benchmark after ``nb``.

    Speeds up the search by considering only some values of NB: the value is
    doubled while below 2, increased by 2 while below 8, and increased by 4
    afterwards (1, 2, 4, 6, 8, 12, 16, ...).
    """
    if nb < 2:
        return nb * 2
    if nb < 8:
        return nb + 2
    return nb + 4


def write_rtc_table(f, name, table):
    """Write ``table`` to ``f`` as a C++ vector-of-arrays initializer called ``name``.

    Args:
        f: Writable text file object.
        name: Identifier of the generated C++ variable.
        table: ``pandas.DataFrame`` containing at least the ``RTC_COLUMNS``
            columns; each row becomes one ``{...}`` record.
    """
    if table.empty:
        # to_string() on an empty frame returns a descriptive message rather
        # than rows; emit an empty initializer instead.
        rows = []
    else:
        rows = table.to_string(
            header=False,
            index=False,
            justify="right",
            columns=RTC_COLUMNS).split("\n")

    f.write(RTC_RULE)
    f.write(
        f"std::vector<std::array<int, RECORD_LENGTH_RTC> > {name}" +
        " = {\n")
    last = len(rows) - 1
    for i, row in enumerate(rows):
        # Insert a comma after every number that is followed by whitespace,
        # turning the space-aligned row into a C++ initializer list.
        f.write("    {" + re.sub(r"([0-9])(\s+)", r"\1,\2", row) +
                ("},\n" if i < last else "}\n"))
    f.write("};\n")


def main():
    """Benchmark a range of NB values and write ``<arch>_rtc.h`` next to this script.

    For every benchmark configuration the NB with the best MFLOPS is recorded;
    a later NB replaces the recorded one only when it is more than 5% faster.
    Rows with ``TRANS == 1`` go to ``drtc_t_<arch>`` and rows with
    ``TRANS == 0`` go to ``drtc_n_<arch>``.
    """
    # Command line arguments
    parser = argparse.ArgumentParser("MAGMA RTC autotuning")
    parser.add_argument(
        "-arch",
        help="Device architecture name for tuning data",
        required=True)
    parser.add_argument(
        "-max-nb",
        help="Maximum block size NB to consider for autotuning",
        default=32,
        type=int)
    parser.add_argument(
        "-build-cmd",
        help="Command used to build libCEED from the source root directory",
        default="make")
    parser.add_argument(
        "-ceed",
        help="Ceed resource specifier",
        default="/cpu/self")
    args = parser.parse_args()

    # The search starts at NB = 1; a smaller maximum would run no benchmark
    # and leave nothing to write.
    if args.max_nb < 1:
        parser.error("-max-nb must be at least 1")

    best = None
    nb = 1
    while nb <= args.max_nb:
        # Run the benchmarks
        start = time.perf_counter()
        data_nb = benchmark(nb, args.build_cmd, args.ceed,
                            f"{script_dir}/output-nb-{nb}.txt")
        print(
            f"Finished benchmarks for NB = {nb}, backend = {args.ceed} ({time.perf_counter() - start} s)")

        # Save the data for the highest performing NB
        if best is None:
            best = pd.DataFrame(data_nb)
            best["NB"] = nb
        else:
            idx = data_nb["MFLOPS"] > 1.05 * best["MFLOPS"]
            best.loc[idx, "NB"] = nb
            best.loc[idx, "MFLOPS"] = data_nb.loc[idx, "MFLOPS"]

        nb = next_nb(nb)

    # Print the results
    with open(f"{script_dir}/{args.arch}_rtc.h", "w") as f:
        f.write(RTC_RULE)
        f.write(f"// auto-generated from data on {args.arch}\n\n")

        write_rtc_table(f, f"drtc_t_{args.arch}", best.loc[best["TRANS"] == 1])
        f.write("\n")
        write_rtc_table(f, f"drtc_n_{args.arch}", best.loc[best["TRANS"] == 0])


if __name__ == "__main__":
    main()
164