{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "G4yBrceuFbf3"
},
"source": [
"# ColabFold/AlphaFold2 Notebook\n",
"\n",
"## ColabFold v1.5.3: AlphaFold2 using MMseqs2\n",
"\n",
"Easy-to-use protein structure and complex prediction using [AlphaFold2](https://www.nature.com/articles/s41586-021-03819-2) and [AlphaFold2-multimer](https://www.biorxiv.org/content/10.1101/2021.10.04.463034v1). Sequence alignments/templates are generated through [MMseqs2](https://mmseqs.com) and [HHsearch](https://github.com/soedinglab/hh-suite). For more details, see the bottom of the notebook, check out the [ColabFold GitHub](https://github.com/sokrypton/ColabFold) and read our manuscript.\n",
"Old versions: [v1.4](https://colab.research.google.com/github/sokrypton/ColabFold/blob/v1.4.0/AlphaFold2.ipynb), [v1.5.1](https://colab.research.google.com/github/sokrypton/ColabFold/blob/v1.5.1/AlphaFold2.ipynb), [v1.5.2](https://colab.research.google.com/github/sokrypton/ColabFold/blob/v1.5.2/AlphaFold2.ipynb)\n",
"\n",
"[Mirdita M, Schütze K, Moriwaki Y, Heo L, Ovchinnikov S, Steinegger M. ColabFold: Making protein folding accessible to all.\n",
"*Nature Methods*, 2022](https://www.nature.com/articles/s41592-022-01488-1)\n",
"\n",
"-----------\n",
"\n",
"### News\n",
"- 2023/07/31: The ColabFold MSA server is back to normal. It was using an older database (UniRef30 2202/PDB70 220313) from July 27 (~8:30 AM CEST) to July 31 (~11:10 AM CEST).\n",
"- 2023/06/12: New databases! UniRef30 updated to 2023_02 and PDB to 230517. We now use PDB100 instead of PDB70 (see [notes](#pdb100)).\n",
"- 2023/06/12: We introduced a new default pairing strategy. Previously, for multimer predictions with more than 2 chains, we only paired if all sequences taxonomically matched (\"complete\" pairing). The new default \"greedy\" strategy pairs any taxonomically matching subsets."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"collapsed": true,
"id": "kOblAo-xetgx",
"jupyter": {
"outputs_hidden": true
},
"outputId": "fd5d6895-cd56-4678-8e88-cb82cca50871"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"jobname test_a5e17\n",
"sequence PIAQIHILEGRSDEQKETLIREVSEAISRSLDAPLTSVRVIITEMAKGHFGIGGELASK\n",
"length 59\n"
]
}
],
"source": [
"# @title Input protein sequence(s), then hit `Runtime` -> `Run all` { display-mode: \"form\" }\n",
"from google.colab import files\n",
"import os\n",
"import re\n",
"import hashlib\n",
"import random\n",
"\n",
"from sys import version_info\n",
"python_version = f\"{version_info.major}.{version_info.minor}\"\n",
"\n",
"def add_hash(x,y):\n",
" return x+\"_\"+hashlib.sha1(y.encode()).hexdigest()[:5]\n",
"\n",
"query_sequence = 'PIAQIHILEGRSDEQKETLIREVSEAISRSLDAPLTSVRVIITEMAKGHFGIGGELASK' #@param {type:\"string\"}\n",
"#@markdown - Use `:` to specify inter-protein chainbreaks for **modeling complexes** (supports homo- and hetero-oligomers). For example, **PI...SK:PI...SK** for a homodimer.\n",
"jobname = 'test' #@param {type:\"string\"}\n",
"# number of top-ranked models to relax with Amber\n",
"num_relax = 0 #@param [0, 1, 5] {type:\"raw\"}\n",
"#@markdown - Specify how many of the top-ranked structures to relax using Amber.\n",
"template_mode = \"none\" #@param [\"none\", \"pdb100\",\"custom\"]\n",
"#@markdown - `none` = no template information is used. `pdb100` = detect templates in pdb100 (see [notes](#pdb100)). `custom` = upload and search your own templates (PDB or mmCIF format, see [notes](#custom_templates)).\n",
"\n",
"use_amber = num_relax > 0\n",
"\n",
"# remove whitespace\n",
"query_sequence = \"\".join(query_sequence.split())\n",
"\n",
"basejobname = \"\".join(jobname.split())\n",
"basejobname = re.sub(r'\\W+', '', basejobname)\n",
"jobname = add_hash(basejobname, query_sequence)\n",
"\n",
"# if a directory with this jobname already exists, append the first free numeric suffix\n",
"def check(folder):\n",
" if os.path.exists(folder):\n",
" return False\n",
" else:\n",
" return True\n",
"if not check(jobname):\n",
" n = 0\n",
" while not check(f\"{jobname}_{n}\"): n += 1\n",
" jobname = f\"{jobname}_{n}\"\n",
"\n",
"# make directory to save results\n",
"os.makedirs(jobname, exist_ok=True)\n",
"\n",
"# save queries\n",
"queries_path = os.path.join(jobname, f\"{jobname}.csv\")\n",
"with open(queries_path, \"w\") as text_file:\n",
" text_file.write(f\"id,sequence\\n{jobname},{query_sequence}\")\n",
"\n",
"if template_mode == \"pdb100\":\n",
" use_templates = True\n",
" custom_template_path = None\n",
"elif template_mode == \"custom\":\n",
" custom_template_path = os.path.join(jobname,f\"template\")\n",
" os.makedirs(custom_template_path, exist_ok=True)\n",
" uploaded = files.upload()\n",
" use_templates = True\n",
" for fn in uploaded.keys():\n",
" os.rename(fn,os.path.join(custom_template_path,fn))\n",
"else:\n",
" custom_template_path = None\n",
" use_templates = False\n",
"\n",
"print(\"jobname\",jobname)\n",
"print(\"sequence\",query_sequence)\n",
"print(\"length\",len(query_sequence.replace(\":\",\"\")))\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"collapsed": true,
"id": "AzIKiDiCaHAn",
"jupyter": {
"outputs_hidden": true
},
"outputId": "5b6b9ab0-af96-45ef-c903-43ae13525af0"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"installing colabfold...\n",
"CPU times: user 114 ms, sys: 27.2 ms, total: 142 ms\n",
"Wall time: 42.9 s\n"
]
}
],
"source": [
"# @title Install dependencies { display-mode: \"form\" }\n",
"%%time\n",
"import os\n",
"USE_AMBER = use_amber\n",
"USE_TEMPLATES = use_templates\n",
"PYTHON_VERSION = python_version\n",
"\n",
"if not os.path.isfile(\"COLABFOLD_READY\"):\n",
" print(\"installing colabfold...\")\n",
" os.system(\"pip install -q --no-warn-conflicts 'colabfold[alphafold-minus-jax] @ git+https://github.com/sokrypton/ColabFold'\")\n",
" os.system(\"pip install --upgrade dm-haiku\")\n",
" os.system(\"ln -s /usr/local/lib/python3.*/dist-packages/colabfold colabfold\")\n",
" os.system(\"ln -s /usr/local/lib/python3.*/dist-packages/alphafold alphafold\")\n",
" # patch for jax > 0.3.25\n",
" os.system(\"sed -i 's/weights = jax.nn.softmax(logits)/logits=jnp.clip(logits,-1e8,1e8);weights=jax.nn.softmax(logits)/g' alphafold/model/modules.py\")\n",
" os.system(\"touch COLABFOLD_READY\")\n",
"\n",
"if USE_AMBER or USE_TEMPLATES:\n",
" if not os.path.isfile(\"CONDA_READY\"):\n",
" print(\"installing conda...\")\n",
" os.system(\"wget -qnc https://github.com/conda-forge/miniforge/releases/latest/download/Mambaforge-Linux-x86_64.sh\")\n",
" os.system(\"bash Mambaforge-Linux-x86_64.sh -bfp /usr/local\")\n",
" os.system(\"mamba config --set auto_update_conda false\")\n",
" os.system(\"touch CONDA_READY\")\n",
"\n",
"if USE_TEMPLATES and not os.path.isfile(\"HH_READY\") and USE_AMBER and not os.path.isfile(\"AMBER_READY\"):\n",
" print(\"installing hhsuite and amber...\")\n",
" os.system(f\"mamba install -y -c conda-forge -c bioconda kalign2=2.04 hhsuite=3.3.0 openmm=7.7.0 python='{PYTHON_VERSION}' pdbfixer\")\n",
" os.system(\"touch HH_READY\")\n",
" os.system(\"touch AMBER_READY\")\n",
"else:\n",
" if USE_TEMPLATES and not os.path.isfile(\"HH_READY\"):\n",
" print(\"installing hhsuite...\")\n",
" os.system(f\"mamba install -y -c conda-forge -c bioconda kalign2=2.04 hhsuite=3.3.0 python='{PYTHON_VERSION}'\")\n",
" os.system(\"touch HH_READY\")\n",
" if USE_AMBER and not os.path.isfile(\"AMBER_READY\"):\n",
" print(\"installing amber...\")\n",
" os.system(f\"mamba install -y -c conda-forge openmm=7.7.0 python='{PYTHON_VERSION}' pdbfixer\")\n",
" os.system(\"touch AMBER_READY\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "C2_sh2uAonJH"
},
"outputs": [],
"source": [
"#@markdown ### MSA options (custom MSA upload, single sequence, pairing mode)\n",
"msa_mode = \"mmseqs2_uniref_env\" #@param [\"mmseqs2_uniref_env\", \"mmseqs2_uniref\",\"single_sequence\",\"custom\"]\n",
"pair_mode = \"unpaired_paired\" #@param [\"unpaired_paired\",\"paired\",\"unpaired\"] {type:\"string\"}\n",
"#@markdown - \"unpaired_paired\" = pair sequences from the same species + unpaired MSA, \"unpaired\" = separate MSA for each chain, \"paired\" = only use paired sequences.\n",
"\n",
"# decide which a3m to use\n",
"if \"mmseqs2\" in msa_mode:\n",
" a3m_file = os.path.join(jobname,f\"{jobname}.a3m\")\n",
"\n",
"elif msa_mode == \"custom\":\n",
" a3m_file = os.path.join(jobname,f\"{jobname}.custom.a3m\")\n",
" if not os.path.isfile(a3m_file):\n",
" custom_msa_dict = files.upload()\n",
" custom_msa = list(custom_msa_dict.keys())[0]\n",
" header = 0\n",
" import fileinput\n",
" for line in fileinput.FileInput(custom_msa,inplace=1):\n",
" if line.startswith(\">\"):\n",
" header = header + 1\n",
" if not line.rstrip():\n",
" continue\n",
"      if not line.startswith(\">\") and header == 1:\n",
" query_sequence = line.rstrip()\n",
" print(line, end='')\n",
"\n",
" os.rename(custom_msa, a3m_file)\n",
" queries_path=a3m_file\n",
" print(f\"moving {custom_msa} to {a3m_file}\")\n",
"\n",
"else:\n",
" a3m_file = os.path.join(jobname,f\"{jobname}.single_sequence.a3m\")\n",
" with open(a3m_file, \"w\") as text_file:\n",
" text_file.write(\">1\\n%s\" % query_sequence)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ADDuaolKmjGW"
},
"outputs": [],
"source": [
"# @title { display-mode: \"form\" }\n",
"#@markdown ### Advanced settings\n",
"model_type = \"auto\" #@param [\"auto\", \"alphafold2_ptm\", \"alphafold2_multimer_v1\", \"alphafold2_multimer_v2\", \"alphafold2_multimer_v3\"]\n",
"#@markdown - If `auto` is selected, `alphafold2_ptm` is used for monomer prediction and `alphafold2_multimer_v3` for complex prediction.\n",
"#@markdown Any of the model types can be used (regardless of whether the input is a monomer or a complex).\n",
"num_recycles = \"3\" #@param [\"auto\", \"0\", \"1\", \"3\", \"6\", \"12\", \"24\", \"48\"]\n",
"#@markdown - If `auto` is selected, `num_recycles=20` is used if `model_type=alphafold2_multimer_v3`, else `num_recycles=3`.\n",
"recycle_early_stop_tolerance = \"auto\" #@param [\"auto\", \"0.0\", \"0.5\", \"1.0\"]\n",
"#@markdown - If `auto` is selected, `tol=0.5` is used if `model_type=alphafold2_multimer_v3`, else `tol=0.0`.\n",
"relax_max_iterations = 200 #@param [0, 200, 2000] {type:\"raw\"}\n",
"#@markdown - Maximum number of Amber relax iterations; `0` = unlimited (AlphaFold2 default, can take very long).\n",
"pairing_strategy = \"greedy\" #@param [\"greedy\", \"complete\"] {type:\"string\"}\n",
"#@markdown - `greedy` = pair any taxonomically matching subsets, `complete` = all sequences have to match in one line.\n",
"\n",
"\n",
"#@markdown #### Sample settings\n",
"#@markdown - Enable dropout and increase the number of seeds to sample predictions from the model's uncertainty.\n",
"#@markdown - Decrease `max_msa` to increase uncertainty.\n",
"max_msa = \"auto\" #@param [\"auto\", \"512:1024\", \"256:512\", \"64:128\", \"32:64\", \"16:32\"]\n",
"num_seeds = 1 #@param [1,2,4,8,16] {type:\"raw\"}\n",
"use_dropout = False #@param {type:\"boolean\"}\n",
"\n",
"num_recycles = None if num_recycles == \"auto\" else int(num_recycles)\n",
"recycle_early_stop_tolerance = None if recycle_early_stop_tolerance == \"auto\" else float(recycle_early_stop_tolerance)\n",
"if max_msa == \"auto\": max_msa = None\n",
"\n",
"#@markdown #### Save settings\n",
"save_all = False #@param {type:\"boolean\"}\n",
"save_recycles = False #@param {type:\"boolean\"}\n",
"save_to_google_drive = False #@param {type:\"boolean\"}\n",
"#@markdown - If the save_to_google_drive option is selected, the result zip will be uploaded to your Google Drive.\n",
"dpi = 200 #@param {type:\"integer\"}\n",
"#@markdown - Set the DPI for image resolution.\n",
"\n",
"if save_to_google_drive:\n",
" from pydrive.drive import GoogleDrive\n",
" from pydrive.auth import GoogleAuth\n",
" from google.colab import auth\n",
" from oauth2client.client import GoogleCredentials\n",
" auth.authenticate_user()\n",
" gauth = GoogleAuth()\n",
" gauth.credentials = GoogleCredentials.get_application_default()\n",
" drive = GoogleDrive(gauth)\n",
" print(\"You are logged into Google Drive and are good to go!\")\n",
"\n",
"#@markdown Don't forget to hit `Runtime` -> `Run all` after updating the form."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"collapsed": true,
"id": "mbaIO9pWjaN0",
"jupyter": {
"outputs_hidden": true
},
"outputId": "9beb72b5-c2a3-4e28-a80e-f9583095f4ff"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Downloading alphafold2 weights to .: 100%|██████████| 3.47G/3.47G [02:40<00:00, 23.2MB/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-12-07 22:02:42,821 Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: \"rocm\". Available platform names are: CUDA\n",
"2023-12-07 22:02:42,823 Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory\n",
"2023-12-07 22:02:44,726 Running on GPU\n",
"2023-12-07 22:02:44,904 Found 4 citations for tools or databases\n",
"2023-12-07 22:02:44,904 Query 1/1: test_a5e17 (length 59)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"COMPLETE: 100%|██████████| 150/150 [elapsed: 00:01 remaining: 00:00]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2023-12-07 22:02:47,182 Setting max_seq=512, max_extra_seq=5120\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "ignored",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 64\u001b[0m \u001b[0mdownload_alphafold_params\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_type\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mPath\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\".\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 65\u001b[0;31m results = run(\n\u001b[0m\u001b[1;32m 66\u001b[0m \u001b[0mqueries\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mqueries\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 67\u001b[0m \u001b[0mresult_dir\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mresult_dir\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/content/colabfold/batch.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(queries, result_dir, num_models, is_complex, num_recycles, recycle_early_stop_tolerance, model_order, num_ensemble, model_type, msa_mode, use_templates, custom_template_path, num_relax, relax_max_iterations, relax_tolerance, relax_stiffness, relax_max_outer_iterations, keep_existing_results, rank_by, pair_mode, pairing_strategy, data_dir, host_url, user_agent, random_seed, num_seeds, recompile_padding, zip_results, prediction_callback, save_single_representations, save_pair_representations, save_all, save_recycles, use_dropout, use_gpu_relax, stop_at_score, dpi, max_seq, max_extra_seq, pdb_hit_file, local_pdb_path, use_cluster_profile, feature_dict_callback, **kwargs)\u001b[0m\n\u001b[1;32m 1568\u001b[0m \u001b[0mfirst_job\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1569\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1570\u001b[0;31m results = predict_structure(\n\u001b[0m\u001b[1;32m 1571\u001b[0m \u001b[0mprefix\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mjobname\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1572\u001b[0m \u001b[0mresult_dir\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mresult_dir\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/content/colabfold/batch.py\u001b[0m in \u001b[0;36mpredict_structure\u001b[0;34m(prefix, result_dir, feature_dict, is_complex, use_templates, sequences_lengths, pad_len, model_type, model_runner_and_params, num_relax, relax_max_iterations, relax_tolerance, relax_stiffness, relax_max_outer_iterations, rank_by, random_seed, num_seeds, stop_at_score, prediction_callback, use_gpu_relax, save_all, save_single_representations, save_pair_representations, save_recycles)\u001b[0m\n\u001b[1;32m 419\u001b[0m \u001b[0;31m# predict\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 420\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrecycles\u001b[0m \u001b[0;34m=\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 421\u001b[0;31m model_runner.predict(input_features,\n\u001b[0m\u001b[1;32m 422\u001b[0m \u001b[0mrandom_seed\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mseed\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 423\u001b[0m \u001b[0mreturn_representations\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mreturn_representations\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/content/alphafold/model/model.py\u001b[0m in \u001b[0;36mpredict\u001b[0;34m(self, feat, random_seed, return_representations, callback)\u001b[0m\n\u001b[1;32m 183\u001b[0m \u001b[0;31m# run\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 184\u001b[0m \u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msub_key\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandom\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 185\u001b[0;31m \u001b[0mresult\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprev\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msub_key\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msub_feat\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprev\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 186\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 187\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mreturn_representations\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/content/alphafold/model/model.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(key, feat, prev)\u001b[0m\n\u001b[1;32m 163\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mv\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfloat16\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 164\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 165\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_jnp_to_np\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mfeat\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"prev\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mprev\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 166\u001b[0m \u001b[0mprev\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpop\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"prev\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 167\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprev\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
" \u001b[0;31m[... skipping hidden 1 frame]\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py\u001b[0m in \u001b[0;36mcache_miss\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 254\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0mapi_boundary\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 255\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcache_miss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 256\u001b[0;31m outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(\n\u001b[0m\u001b[1;32m 257\u001b[0m fun, infer_params_fn, *args, **kwargs)\n\u001b[1;32m 258\u001b[0m \u001b[0mexecutable\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_read_most_recent_pjit_call_executable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjaxpr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py\u001b[0m in \u001b[0;36m_python_pjit_helper\u001b[0;34m(fun, infer_params_fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 165\u001b[0m \u001b[0mdispatch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcheck_arg\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 166\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 167\u001b[0;31m \u001b[0mout_flat\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpjit_p\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbind\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs_flat\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 168\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mpxla\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDeviceAssignmentMismatchError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 169\u001b[0m \u001b[0mfails\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/core.py\u001b[0m in \u001b[0;36mbind\u001b[0;34m(self, *args, **params)\u001b[0m\n\u001b[1;32m 2654\u001b[0m top_trace = (top_trace if not axis_main or axis_main.level < top_trace.level\n\u001b[1;32m 2655\u001b[0m else axis_main.with_cur_sublevel())\n\u001b[0;32m-> 2656\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbind_with_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtop_trace\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2657\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2658\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/core.py\u001b[0m in \u001b[0;36mbind_with_trace\u001b[0;34m(self, trace, args, params)\u001b[0m\n\u001b[1;32m 386\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 387\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mbind_with_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 388\u001b[0;31m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocess_primitive\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfull_raise\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 389\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfull_lower\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmultiple_results\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mfull_lower\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 390\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/core.py\u001b[0m in \u001b[0;36mprocess_primitive\u001b[0;34m(self, primitive, tracers, params)\u001b[0m\n\u001b[1;32m 866\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 867\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mprocess_primitive\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprimitive\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 868\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mprimitive\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimpl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 869\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 870\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mprocess_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprimitive\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py\u001b[0m in \u001b[0;36m_pjit_call_impl\u001b[0;34m(jaxpr, in_shardings, out_shardings, resource_env, donated_invars, name, keep_unused, inline, *args)\u001b[0m\n\u001b[1;32m 1210\u001b[0m has_explicit_sharding = _pjit_explicit_sharding(\n\u001b[1;32m 1211\u001b[0m in_shardings, out_shardings, None, None)\n\u001b[0;32m-> 1212\u001b[0;31m return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,\n\u001b[0m\u001b[1;32m 1213\u001b[0m \u001b[0mtree_util\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdispatch_registry\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1214\u001b[0m _get_cpp_global_cache(has_explicit_sharding))(*args)\n",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py\u001b[0m in \u001b[0;36mcall_impl_cache_miss\u001b[0;34m(*args_, **kwargs_)\u001b[0m\n\u001b[1;32m 1194\u001b[0m donated_invars, name, keep_unused, inline):\n\u001b[1;32m 1195\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcall_impl_cache_miss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs_\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1196\u001b[0;31m out_flat, compiled = _pjit_call_impl_python(\n\u001b[0m\u001b[1;32m 1197\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mjaxpr\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_shardings\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0min_shardings\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1198\u001b[0m \u001b[0mout_shardings\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mout_shardings\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresource_env\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mresource_env\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py\u001b[0m in \u001b[0;36m_pjit_call_impl_python\u001b[0;34m(jaxpr, in_shardings, out_shardings, resource_env, donated_invars, name, keep_unused, inline, *args)\u001b[0m\n\u001b[1;32m 1127\u001b[0m resource_env.physical_mesh if resource_env is not None else None)\n\u001b[1;32m 1128\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1129\u001b[0;31m compiled = _pjit_lower(\n\u001b[0m\u001b[1;32m 1130\u001b[0m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_shardings\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_shardings\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresource_env\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1131\u001b[0m \u001b[0mdonated_invars\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkeep_unused\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minline\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py\u001b[0m in \u001b[0;36m_pjit_lower\u001b[0;34m(jaxpr, in_shardings, out_shardings, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1258\u001b[0m \u001b[0min_shardings\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mSameDeviceAssignmentTuple\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtuple\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0min_shardings\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mda\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1259\u001b[0m \u001b[0mout_shardings\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mSameDeviceAssignmentTuple\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtuple\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout_shardings\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mda\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1260\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0m_pjit_lower_cached\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_shardings\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_shardings\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1261\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1262\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py\u001b[0m in \u001b[0;36m_pjit_lower_cached\u001b[0;34m(jaxpr, sdat_in_shardings, sdat_out_shardings, resource_env, donated_invars, name, keep_unused, inline, lowering_parameters)\u001b[0m\n\u001b[1;32m 1297\u001b[0m lowering_parameters=lowering_parameters)\n\u001b[1;32m 1298\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1299\u001b[0;31m return pxla.lower_sharding_computation(\n\u001b[0m\u001b[1;32m 1300\u001b[0m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mapi_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_shardings\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_shardings\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1301\u001b[0m \u001b[0mtuple\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdonated_invars\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtuple\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjaxpr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0min_avals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/profiler.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 338\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 339\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mTraceAnnotation\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mdecorator_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 340\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 341\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 342\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/pxla.py\u001b[0m in \u001b[0;36mlower_sharding_computation\u001b[0;34m(fun_or_jaxpr, api_name, fun_name, in_shardings, out_shardings, donated_invars, global_in_avals, keep_unused, inline, devices_from_context, lowering_parameters)\u001b[0m\n\u001b[1;32m 2029\u001b[0m \u001b[0msemantic_out_shardings\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mSemanticallyEqualShardings\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout_shardings\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2030\u001b[0m (module, keepalive, host_callbacks, unordered_effects, ordered_effects,\n\u001b[0;32m-> 2031\u001b[0;31m \u001b[0mnreps\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtuple_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mshape_poly_state\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_cached_lowering_to_hlo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2032\u001b[0m \u001b[0mclosed_jaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mapi_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbackend\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msemantic_in_shardings\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2033\u001b[0m \u001b[0msemantic_out_shardings\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mda_object\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/pxla.py\u001b[0m in \u001b[0;36m_cached_lowering_to_hlo\u001b[0;34m(closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings, semantic_out_shardings, da_object, donated_invars, name_stack, all_default_mem_kind, lowering_parameters)\u001b[0m\n\u001b[1;32m 1830\u001b[0m \u001b[0;34m\"Finished jaxpr to MLIR module conversion {fun_name} in {elapsed_time} sec\"\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1831\u001b[0m fun_name=str(name_stack), event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT):\n\u001b[0;32m-> 1832\u001b[0;31m lowering_result = mlir.lower_jaxpr_to_module(\n\u001b[0m\u001b[1;32m 1833\u001b[0m \u001b[0mmodule_name\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1834\u001b[0m \u001b[0mclosed_jaxpr\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py\u001b[0m in \u001b[0;36mlower_jaxpr_to_module\u001b[0;34m(module_name, jaxpr, ordered_effects, backend_or_name, platforms, axis_context, name_stack, donated_args, replicated_args, arg_shardings, result_shardings, arg_names, result_names, num_replicas, num_partitions, all_default_mem_kind, lowering_parameters)\u001b[0m\n\u001b[1;32m 804\u001b[0m \u001b[0mattrs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"mhlo.num_partitions\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mi32_attr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnum_partitions\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 805\u001b[0m \u001b[0mreplace_tokens_with_dummy\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlowering_parameters\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreplace_tokens_with_dummy\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 806\u001b[0;31m lower_jaxpr_to_fun(\n\u001b[0m\u001b[1;32m 807\u001b[0m \u001b[0mctx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"main\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mordered_effects\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpublic\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 808\u001b[0m \u001b[0mcreate_tokens\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mreplace_tokens_with_dummy\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py\u001b[0m in \u001b[0;36mlower_jaxpr_to_fun\u001b[0;34m(ctx, name, jaxpr, effects, create_tokens, public, replace_tokens_with_dummy, replicated_args, arg_shardings, result_shardings, use_sharding_annotations, input_output_aliases, num_output_tokens, api_name, arg_names, result_names, arg_memory_kinds, result_memory_kinds)\u001b[0m\n\u001b[1;32m 1211\u001b[0m \u001b[0mcallee_name_stack\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mctx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname_stack\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mextend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mutil\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwrap_name\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mapi_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1212\u001b[0m \u001b[0mconsts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mir_constants\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mxla\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcanonicalize_dtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mjaxpr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconsts\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1213\u001b[0;31m out_vals, tokens_out = jaxpr_subcomp(\n\u001b[0m\u001b[1;32m 1214\u001b[0m \u001b[0mctx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreplace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname_stack\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcallee_name_stack\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mjaxpr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtokens_in\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1215\u001b[0m consts, *args, 
dim_var_values=dim_var_values)\n",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py\u001b[0m in \u001b[0;36mjaxpr_subcomp\u001b[0;34m(ctx, jaxpr, tokens, consts, dim_var_values, *args)\u001b[0m\n\u001b[1;32m 1429\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mctx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mplatforms\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1430\u001b[0m \u001b[0;31m# Classic, single-platform lowering\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1431\u001b[0;31m \u001b[0mans\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrule_ctx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mrule_inputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0meqn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1432\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1433\u001b[0m ans = lower_multi_platform(rule_ctx, str(eqn), rules,\n",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py\u001b[0m in \u001b[0;36mf_lowered\u001b[0;34m(ctx, *args, **params)\u001b[0m\n\u001b[1;32m 1626\u001b[0m \u001b[0;31m# TODO(frostig,mattjj): check ctx.avals_out against jaxpr avals out?\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1627\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1628\u001b[0;31m out, tokens = jaxpr_subcomp(\n\u001b[0m\u001b[1;32m 1629\u001b[0m \u001b[0mctx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodule_context\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mctx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtokens_in\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_ir_consts\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconsts\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1630\u001b[0m *map(wrap_singleton_ir_values, args), dim_var_values=ctx.dim_var_values)\n",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py\u001b[0m in \u001b[0;36mjaxpr_subcomp\u001b[0;34m(ctx, jaxpr, tokens, consts, dim_var_values, *args)\u001b[0m\n\u001b[1;32m 1429\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mctx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mplatforms\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1430\u001b[0m \u001b[0;31m# Classic, single-platform lowering\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1431\u001b[0;31m \u001b[0mans\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrule_ctx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mrule_inputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0meqn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1432\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1433\u001b[0m ans = lower_multi_platform(rule_ctx, str(eqn), rules,\n",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/lax/control_flow/loops.py\u001b[0m in \u001b[0;36m_while_lowering\u001b[0;34m(ctx, cond_jaxpr, body_jaxpr, cond_nconsts, body_nconsts, *args)\u001b[0m\n\u001b[1;32m 1670\u001b[0m body_consts = [mlir.ir_constants(xla.canonicalize_dtype(x))\n\u001b[1;32m 1671\u001b[0m for x in body_jaxpr.consts]\n\u001b[0;32m-> 1672\u001b[0;31m new_z, tokens_out = mlir.jaxpr_subcomp(body_ctx, body_jaxpr.jaxpr,\n\u001b[0m\u001b[1;32m 1673\u001b[0m tokens_in, body_consts, *(y + z), dim_var_values=ctx.dim_var_values)\n\u001b[1;32m 1674\u001b[0m \u001b[0mout_tokens\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mtokens_out\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0meff\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0meff\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mbody_effects\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py\u001b[0m in \u001b[0;36mjaxpr_subcomp\u001b[0;34m(ctx, jaxpr, tokens, consts, dim_var_values, *args)\u001b[0m\n\u001b[1;32m 1429\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mctx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mplatforms\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1430\u001b[0m \u001b[0;31m# Classic, single-platform lowering\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1431\u001b[0;31m \u001b[0mans\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrule_ctx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mrule_inputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0meqn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1432\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1433\u001b[0m ans = lower_multi_platform(rule_ctx, str(eqn), rules,\n",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py\u001b[0m in \u001b[0;36mf_lowered\u001b[0;34m(ctx, *args, **params)\u001b[0m\n\u001b[1;32m 1626\u001b[0m \u001b[0;31m# TODO(frostig,mattjj): check ctx.avals_out against jaxpr avals out?\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1627\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1628\u001b[0;31m out, tokens = jaxpr_subcomp(\n\u001b[0m\u001b[1;32m 1629\u001b[0m \u001b[0mctx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodule_context\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mctx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtokens_in\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_ir_consts\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconsts\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1630\u001b[0m *map(wrap_singleton_ir_values, args), dim_var_values=ctx.dim_var_values)\n",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py\u001b[0m in \u001b[0;36mjaxpr_subcomp\u001b[0;34m(ctx, jaxpr, tokens, consts, dim_var_values, *args)\u001b[0m\n\u001b[1;32m 1429\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mctx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mplatforms\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1430\u001b[0m \u001b[0;31m# Classic, single-platform lowering\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1431\u001b[0;31m \u001b[0mans\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrule_ctx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mrule_inputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0meqn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1432\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1433\u001b[0m ans = lower_multi_platform(rule_ctx, str(eqn), rules,\n",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/lax/control_flow/loops.py\u001b[0m in \u001b[0;36m_while_lowering\u001b[0;34m(ctx, cond_jaxpr, body_jaxpr, cond_nconsts, body_nconsts, *args)\u001b[0m\n\u001b[1;32m 1632\u001b[0m \u001b[0mmlir\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mir_constants\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mxla\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcanonicalize_dtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mcond_jaxpr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconsts\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1633\u001b[0m ]\n\u001b[0;32m-> 1634\u001b[0;31m ((pred,),), _ = mlir.jaxpr_subcomp(\n\u001b[0m\u001b[1;32m 1635\u001b[0m \u001b[0mcond_ctx\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1636\u001b[0m \u001b[0mcond_jaxpr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/mlir.py\u001b[0m in \u001b[0;36mjaxpr_subcomp\u001b[0;34m(ctx, jaxpr, tokens, consts, dim_var_values, *args)\u001b[0m\n\u001b[1;32m 1313\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc_op\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1314\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1315\u001b[0;31m def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,\n\u001b[0m\u001b[1;32m 1316\u001b[0m \u001b[0mtokens\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTokenSet\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1317\u001b[0m \u001b[0mconsts\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mSequence\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mSequence\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mir\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mValue\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"#@title Run Prediction\n",
"display_images = True #@param {type:\"boolean\"}\n",
"\n",
"import sys\n",
"import warnings\n",
"warnings.simplefilter(action='ignore', category=FutureWarning)\n",
"from Bio import BiopythonDeprecationWarning\n",
"warnings.simplefilter(action='ignore', category=BiopythonDeprecationWarning)\n",
"from pathlib import Path\n",
"from colabfold.download import download_alphafold_params, default_data_dir\n",
"from colabfold.utils import setup_logging\n",
"from colabfold.batch import get_queries, run, set_model_type\n",
"from colabfold.plot import plot_msa_v2\n",
"\n",
"import os\n",
"import numpy as np\n",
"try:\n",
" K80_chk = os.popen('nvidia-smi | grep \"Tesla K80\" | wc -l').read()\n",
"except:\n",
" K80_chk = \"0\"\n",
" pass\n",
"if \"1\" in K80_chk:\n",
" print(\"WARNING: found GPU Tesla K80: limited to total length < 1000\")\n",
" if \"TF_FORCE_UNIFIED_MEMORY\" in os.environ:\n",
" del os.environ[\"TF_FORCE_UNIFIED_MEMORY\"]\n",
" if \"XLA_PYTHON_CLIENT_MEM_FRACTION\" in os.environ:\n",
" del os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"]\n",
"\n",
"from colabfold.colabfold import plot_protein\n",
"from pathlib import Path\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# For some reason we need that to get pdbfixer to import\n",
"if use_amber and f\"/usr/local/lib/python{python_version}/site-packages/\" not in sys.path:\n",
" sys.path.insert(0, f\"/usr/local/lib/python{python_version}/site-packages/\")\n",
"\n",
"def input_features_callback(input_features):\n",
" if display_images:\n",
" plot_msa_v2(input_features)\n",
" plt.show()\n",
" plt.close()\n",
"\n",
"def prediction_callback(protein_obj, length,\n",
" prediction_result, input_features, mode):\n",
" model_name, relaxed = mode\n",
" if not relaxed:\n",
" if display_images:\n",
" fig = plot_protein(protein_obj, Ls=length, dpi=150)\n",
" plt.show()\n",
" plt.close()\n",
"\n",
"result_dir = jobname\n",
"log_filename = os.path.join(jobname,\"log.txt\")\n",
"setup_logging(Path(log_filename))\n",
"\n",
"queries, is_complex = get_queries(queries_path)\n",
"model_type = set_model_type(is_complex, model_type)\n",
"\n",
"if \"multimer\" in model_type and max_msa is not None:\n",
" use_cluster_profile = False\n",
"else:\n",
" use_cluster_profile = True\n",
"\n",
"download_alphafold_params(model_type, Path(\".\"))\n",
"results = run(\n",
" queries=queries,\n",
" result_dir=result_dir,\n",
" use_templates=use_templates,\n",
" custom_template_path=custom_template_path,\n",
" num_relax=num_relax,\n",
" msa_mode=msa_mode,\n",
" model_type=model_type,\n",
" num_models=5,\n",
" num_recycles=num_recycles,\n",
" relax_max_iterations=relax_max_iterations,\n",
" recycle_early_stop_tolerance=recycle_early_stop_tolerance,\n",
" num_seeds=num_seeds,\n",
" use_dropout=use_dropout,\n",
" model_order=[1,2,3,4,5],\n",
" is_complex=is_complex,\n",
" data_dir=Path(\".\"),\n",
" keep_existing_results=False,\n",
" rank_by=\"auto\",\n",
" pair_mode=pair_mode,\n",
" pairing_strategy=pairing_strategy,\n",
" stop_at_score=float(100),\n",
" prediction_callback=prediction_callback,\n",
" dpi=dpi,\n",
" zip_results=False,\n",
" save_all=save_all,\n",
" max_msa=max_msa,\n",
" use_cluster_profile=use_cluster_profile,\n",
" input_features_callback=input_features_callback,\n",
" save_recycles=save_recycles,\n",
" user_agent=\"colabfold/google-colab-main\",\n",
")\n",
"results_zip = f\"{jobname}.result.zip\"\n",
"os.system(f\"zip -r {results_zip} {jobname}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "KK7X9T44pWb7"
},
"outputs": [],
"source": [
"#@title Display 3D structure {run: \"auto\"}\n",
"import py3Dmol\n",
"import glob\n",
"import matplotlib.pyplot as plt\n",
"from colabfold.colabfold import plot_plddt_legend\n",
"from colabfold.colabfold import pymol_color_list, alphabet_list\n",
"rank_num = 1 #@param [\"1\", \"2\", \"3\", \"4\", \"5\"] {type:\"raw\"}\n",
"color = \"lDDT\" #@param [\"chain\", \"lDDT\", \"rainbow\"]\n",
"show_sidechains = False #@param {type:\"boolean\"}\n",
"show_mainchains = False #@param {type:\"boolean\"}\n",
"\n",
"tag = results[\"rank\"][0][rank_num - 1]\n",
"jobname_prefix = \".custom\" if msa_mode == \"custom\" else \"\"\n",
"pdb_filename = f\"{jobname}/{jobname}{jobname_prefix}_unrelaxed_{tag}.pdb\"\n",
"pdb_file = glob.glob(pdb_filename)\n",
"\n",
"def show_pdb(rank_num=1, show_sidechains=False, show_mainchains=False, color=\"lDDT\"):\n",
" model_name = f\"rank_{rank_num}\"\n",
" view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js',)\n",
" view.addModel(open(pdb_file[0],'r').read(),'pdb')\n",
"\n",
" if color == \"lDDT\":\n",
" view.setStyle({'cartoon': {'colorscheme': {'prop':'b','gradient': 'roygb','min':50,'max':90}}})\n",
" elif color == \"rainbow\":\n",
" view.setStyle({'cartoon': {'color':'spectrum'}})\n",
" elif color == \"chain\":\n",
" chains = len(queries[0][1]) + 1 if is_complex else 1\n",
" for n,chain,color in zip(range(chains),alphabet_list,pymol_color_list):\n",
" view.setStyle({'chain':chain},{'cartoon': {'color':color}})\n",
"\n",
" if show_sidechains:\n",
" BB = ['C','O','N']\n",
" view.addStyle({'and':[{'resn':[\"GLY\",\"PRO\"],'invert':True},{'atom':BB,'invert':True}]},\n",
" {'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n",
" view.addStyle({'and':[{'resn':\"GLY\"},{'atom':'CA'}]},\n",
" {'sphere':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n",
" view.addStyle({'and':[{'resn':\"PRO\"},{'atom':['C','O'],'invert':True}]},\n",
" {'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n",
" if show_mainchains:\n",
" BB = ['C','O','N','CA']\n",
" view.addStyle({'atom':BB},{'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n",
"\n",
" view.zoomTo()\n",
" return view\n",
"\n",
"show_pdb(rank_num, show_sidechains, show_mainchains, color).show()\n",
"if color == \"lDDT\":\n",
" plot_plddt_legend().show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "11l8k--10q0C"
},
"outputs": [],
"source": [
"#@title Plots {run: \"auto\"}\n",
"from IPython.display import display, HTML\n",
"import base64\n",
"from html import escape\n",
"\n",
"# see: https://stackoverflow.com/a/53688522\n",
"def image_to_data_url(filename):\n",
" ext = filename.split('.')[-1]\n",
" prefix = f'data:image/{ext};base64,'\n",
" with open(filename, 'rb') as f:\n",
" img = f.read()\n",
" return prefix + base64.b64encode(img).decode('utf-8')\n",
"\n",
"pae = image_to_data_url(os.path.join(jobname,f\"{jobname}{jobname_prefix}_pae.png\"))\n",
"cov = image_to_data_url(os.path.join(jobname,f\"{jobname}{jobname_prefix}_coverage.png\"))\n",
"plddt = image_to_data_url(os.path.join(jobname,f\"{jobname}{jobname_prefix}_plddt.png\"))\n",
"display(HTML(f\"\"\"\n",
"\n",
"
\n",
"
Plots for {escape(jobname)}
\n",
" \n",
" \n",
" \n",
"
\n",
"\"\"\"))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "33g5IIegij5R"
},
"outputs": [],
"source": [
"#@title Package and download results\n",
"#@markdown If you are having issues downloading the result archive, try disabling your adblocker and run this cell again. If that fails click on the little folder icon to the left, navigate to file: `jobname.result.zip`, right-click and select \\\"Download\\\" (see [screenshot](https://pbs.twimg.com/media/E6wRW2lWUAEOuoe?format=jpg&name=small)).\n",
"\n",
"if msa_mode == \"custom\":\n",
" print(\"Don't forget to cite your custom MSA generation method.\")\n",
"\n",
"files.download(f\"{jobname}.result.zip\")\n",
"\n",
"if save_to_google_drive == True and drive:\n",
" uploaded = drive.CreateFile({'title': f\"{jobname}.result.zip\"})\n",
" uploaded.SetContentFile(f\"{jobname}.result.zip\")\n",
" uploaded.Upload()\n",
" print(f\"Uploaded {jobname}.result.zip to Google Drive with ID {uploaded.get('id')}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UGUBLzB3C6WN",
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"# Instructions \n",
"**Quick start**\n",
"1. Paste your protein sequence(s) in the input field.\n",
"2. Press \"Runtime\" -> \"Run all\".\n",
"3. The pipeline consists of 5 steps. The currently running step is indicated by a circle with a stop sign next to it.\n",
"\n",
"**Result zip file contents**\n",
"\n",
"1. PDB formatted structures sorted by avg. pLDDT and complexes are sorted by pTMscore. (unrelaxed and relaxed if `use_amber` is enabled).\n",
"2. Plots of the model quality.\n",
"3. Plots of the MSA coverage.\n",
"4. Parameter log file.\n",
"5. A3M formatted input MSA.\n",
"6. A `predicted_aligned_error_v1.json` using [AlphaFold-DB's format](https://alphafold.ebi.ac.uk/faq#faq-7) and a `scores.json` for each model which contains an array (list of lists) for PAE, a list with the average pLDDT and the pTMscore.\n",
"7. BibTeX file with citations for all used tools and databases.\n",
"\n",
"At the end of the job a download modal box will pop up with a `jobname.result.zip` file. Additionally, if the `save_to_google_drive` option was selected, the `jobname.result.zip` will be uploaded to your Google Drive.\n",
"\n",
"**MSA generation for complexes**\n",
"\n",
"For the complex prediction we use unpaired and paired MSAs. Unpaired MSA is generated the same way as for the protein structures prediction by searching the UniRef100 and environmental sequences three iterations each.\n",
"\n",
"The paired MSA is generated by searching the UniRef100 database and pairing the best hits sharing the same NCBI taxonomic identifier (=species or sub-species). We only pair sequences if all of the query sequences are present for the respective taxonomic identifier.\n",
"\n",
"**Using a custom MSA as input**\n",
"\n",
"To predict the structure with a custom MSA (A3M formatted): (1) Change the `msa_mode`: to \"custom\", (2) Wait for an upload box to appear at the end of the \"MSA options ...\" box. Upload your A3M. The first fasta entry of the A3M must be the query sequence without gaps.\n",
"\n",
"It is also possilbe to proide custom MSAs for complex predictions. Read more about the format [here](https://github.com/sokrypton/ColabFold/issues/76).\n",
"\n",
"As an alternative for MSA generation the [HHblits Toolkit server](https://toolkit.tuebingen.mpg.de/tools/hhblits) can be used. After submitting your query, click \"Query Template MSA\" -> \"Download Full A3M\". Download the A3M file and upload it in this notebook.\n",
"\n",
"**PDB100** \n",
"\n",
"As of 23/06/08, we have transitioned from using the PDB70 to a 100% clustered PDB, the PDB100. The construction methodology of PDB100 differs from that of PDB70.\n",
"\n",
"The PDB70 was constructed by running each PDB70 representative sequence through [HHblits](https://github.com/soedinglab/hh-suite) against the [Uniclust30](https://uniclust.mmseqs.com/). On the other hand, the PDB100 is built by searching each PDB100 representative structure with [Foldseek](https://github.com/steineggerlab/foldseek) against the [AlphaFold Database](https://alphafold.ebi.ac.uk).\n",
"\n",
"To maintain compatibility with older Notebook versions and local installations, the generated files and API responses will continue to be named \"PDB70\", even though we're now using the PDB100.\n",
"\n",
"**Using custom templates** \n",
"\n",
"To predict the structure with a custom template (PDB or mmCIF formatted): (1) change the `template_mode` to \"custom\" in the execute cell and (2) wait for an upload box to appear at the end of the \"Input Protein\" box. Select and upload your templates (multiple choices are possible).\n",
"\n",
"* Templates must follow the four letter PDB naming with lower case letters.\n",
"\n",
"* Templates in mmCIF format must contain `_entity_poly_seq`. An error is thrown if this field is not present. The field `_pdbx_audit_revision_history.revision_date` is automatically generated if it is not present.\n",
"\n",
"* Templates in PDB format are automatically converted to the mmCIF format. `_entity_poly_seq` and `_pdbx_audit_revision_history.revision_date` are automatically generated.\n",
"\n",
"If you encounter problems, please report them to this [issue](https://github.com/sokrypton/ColabFold/issues/177).\n",
"\n",
"**Comparison to the full AlphaFold2 and AlphaFold2 Colab**\n",
"\n",
"This notebook replaces the homology detection and MSA pairing of AlphaFold2 with MMseqs2. For a comparison against the [AlphaFold2 Colab](https://colab.research.google.com/github/deepmind/alphafold/blob/main/notebooks/AlphaFold.ipynb) and the full [AlphaFold2](https://github.com/deepmind/alphafold) system read our [paper](https://www.nature.com/articles/s41592-022-01488-1).\n",
"\n",
"**Troubleshooting**\n",
"* Check that the runtime type is set to GPU at \"Runtime\" -> \"Change runtime type\".\n",
"* Try to restart the session \"Runtime\" -> \"Factory reset runtime\".\n",
"* Check your input sequence.\n",
"\n",
"**Known issues**\n",
"* Google Colab assigns different types of GPUs with varying amount of memory. Some might not have enough memory to predict the structure for a long sequence.\n",
"* Your browser can block the pop-up for downloading the result file. You can choose the `save_to_google_drive` option to upload to Google Drive instead or manually download the result file: Click on the little folder icon to the left, navigate to file: `jobname.result.zip`, right-click and select \\\"Download\\\" (see [screenshot](https://pbs.twimg.com/media/E6wRW2lWUAEOuoe?format=jpg&name=small)).\n",
"\n",
"**Limitations**\n",
"* Computing resources: Our MMseqs2 API can handle ~20-50k requests per day.\n",
"* MSAs: MMseqs2 is very precise and sensitive but might find less hits compared to HHblits/HMMer searched against BFD or MGnify.\n",
"* We recommend to additionally use the full [AlphaFold2 pipeline](https://github.com/deepmind/alphafold).\n",
"\n",
"**Description of the plots**\n",
"* **Number of sequences per position** - We want to see at least 30 sequences per position, for best performance, ideally 100 sequences.\n",
"* **Predicted lDDT per position** - model confidence (out of 100) at each position. The higher the better.\n",
"* **Predicted Alignment Error** - For homooligomers, this could be a useful metric to assess how confident the model is about the interface. The lower the better.\n",
"\n",
"**Bugs**\n",
"- If you encounter any bugs, please report the issue to https://github.com/sokrypton/ColabFold/issues\n",
"\n",
"**License**\n",
"\n",
"The source code of ColabFold is licensed under [MIT](https://raw.githubusercontent.com/sokrypton/ColabFold/main/LICENSE). Additionally, this notebook uses the AlphaFold2 source code and its parameters licensed under [Apache 2.0](https://raw.githubusercontent.com/deepmind/alphafold/main/LICENSE) and [CC BY 4.0](https://creativecommons.org/licenses/by-sa/4.0/) respectively. Read more about the AlphaFold license [here](https://github.com/deepmind/alphafold).\n",
"\n",
"**Acknowledgments**\n",
"- We thank the AlphaFold team for developing an excellent model and open sourcing the software.\n",
"\n",
"- [KOBIC](https://kobic.re.kr) and [Söding Lab](https://www.mpinat.mpg.de/soeding) for providing the computational resources for the MMseqs2 MSA server.\n",
"\n",
"- Richard Evans for helping to benchmark the ColabFold's Alphafold-multimer support.\n",
"\n",
"- [David Koes](https://github.com/dkoes) for his awesome [py3Dmol](https://3dmol.csb.pitt.edu/) plugin, without whom these notebooks would be quite boring!\n",
"\n",
"- Do-Yoon Kim for creating the ColabFold logo.\n",
"\n",
"- A colab by Sergey Ovchinnikov ([@sokrypton](https://twitter.com/sokrypton)), Milot Mirdita ([@milot_mirdita](https://twitter.com/milot_mirdita)) and Martin Steinegger ([@thesteinegger](https://twitter.com/thesteinegger)).\n"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "T4",
"include_colab_link": true,
"machine_shape": "hm",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
}
},
"nbformat": 4,
"nbformat_minor": 4
}