# Subsample taxa in phylogeny

## Overview

**Goal**: Select a smaller subset of taxa from the 10,575 genomes in the original phylogeny, so that they can be analyzed with more robust phylogenetic methods. Because the full phylogeny is already available, the subsampling takes advantage of the phylogenetic distances and diversity it captures.

Two strategies are used here:

I. Use **prototype selection** based on the phylogenetic distance (sum of lengths of branches connecting each pair of tips).

- Pro: Maximizes the sum of phylogenetic distance (hence the largest diversity).
- Con: Favors long branches, and may pick several closely related taxa at the far end of a long branch.

II. Select clades with the smallest relative evolutionary divergence (**RED**) ([Parks et al., 2018](https://www.nature.com/articles/nbt.4229)), then select one taxon within each clade.

- Pro: Ensures that the basal phylogeny is sufficiently represented.
- Con: Misses long branches, such as the single-taxon phyla.
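For reference, RED as used here (following Parks et al., 2018, and the recurrence implemented in `calc_brlen_metrics` below) can be written as:

$$
\mathrm{RED}(n) =
\begin{cases}
0 & \text{if } n \text{ is the root,}\\
1 & \text{if } n \text{ is a tip,}\\
p + \dfrac{d}{u}\,(1 - p) & \text{otherwise,}
\end{cases}
$$

where $p$ is the RED of the parent of $n$, $d$ is the length of the branch connecting $n$ to its parent, and $u$ is the mean distance from the parent of $n$ to the tips that descend from $n$. Since $d/u \le 1$, RED never decreases from the root toward the tips, so clades with small RED sit close to the root.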

With strategy II, the following criteria are applied sequentially to further select one taxon within each clade:

1. Contains the most marker genes.
2. Contamination level is the lowest.
3. DNA quality score is the highest.

The process stops as soon as a single candidate remains. If more than one candidate remains after all three criteria have been applied, one is chosen at random.

The following two rules are optional:

1. Must be represented in the r-protein tree.
2. Must have a standard Latin species name.

Finally, manual curation is applied for specific applications.


## Dependencies

In [1]:
from random import seed, choice

In [2]:
import numpy as np
import scipy as sp
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
from skbio.tree import TreeNode
from skbio.stats.distance import DistanceMatrix

In [3]:
seed(42)

In [4]:
%matplotlib inline

## Parameters

Number of taxa to keep

In [5]:
ntaxa = 92

## Input files

Original, full-scale phylogenetic tree (10,575 taxa)

In [6]:
tree_fp = 'astral.cons.nid.nwk'

In [7]:
tree = TreeNode.read(tree_fp)
tree.count(tips=True)

10575

Branch support values

In [8]:
supports_fp = 'astral.supports.tsv'

In [9]:
dfs = pd.read_table(supports_fp, index_col=0)
dfs.head(3)

| #node | EN | LPP | QT |
|---|---|---|---|
| N2 | 196.0 | 0.998406 | 0.450953 |
| N3 | 196.0 | 0.998406 | 0.450953 |
| N4 | 124.0 | 0.999993 | 0.53512 |


Ignore low-support branches, i.e., those with EN <= 5 or LPP <= 0.5 (abbreviated as e5p50).

In [10]:
en_th = 5
lpp_th = 0.5
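The thresholds are strict: a branch is kept only if EN > 5 *and* LPP > 0.5, so a branch with EN exactly 5 is discarded. A minimal stdlib-only sketch of the same filter, assuming the support table has the layout suggested by the output above (a `#node` column followed by `EN`, `LPP`, `QT`); the row `N5` is invented for illustration, and `well_supported` is a hypothetical helper, not part of the pipeline:

```python
import csv
import io

# Hypothetical excerpt in the layout of astral.supports.tsv: N2-N4 match the
# table above; N5 is an invented low-support row (EN is not > 5).
SUPPORTS_TSV = """#node\tEN\tLPP\tQT
N2\t196.0\t0.998406\t0.450953
N3\t196.0\t0.998406\t0.450953
N4\t124.0\t0.999993\t0.53512
N5\t5.0\t0.9\t0.4
"""


def well_supported(tsv_text, en_th=5, lpp_th=0.5):
    """Return IDs of nodes with EN > en_th and LPP > lpp_th (strict inequalities)."""
    reader = csv.DictReader(io.StringIO(tsv_text), delimiter='\t')
    return {row['#node'] for row in reader
            if float(row['EN']) > en_th and float(row['LPP']) > lpp_th}


print(sorted(well_supported(SUPPORTS_TSV)))  # ['N2', 'N3', 'N4']
```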

Genome metadata

In [11]:
genomes_fp = '../../genomes/metadata.tsv.xz'

In [12]:
dfg = pd.read_table(genomes_fp, index_col=0)
dfg.columns

Index(['asm_name', 'assembly_accession', 'bioproject', 'biosample',
       'wgs_master', 'seq_rel_date', 'submitter', 'ftp_path', 'img_id',
       'gtdb_id', 'scope', 'assembly_level', 'genome_rep', 'refseq_category',
       'release_type', 'taxid', 'species_taxid', 'organism_name',
       'infraspecific_name', 'isolate', 'superkingdom', 'phylum', 'class',
       'order', 'family', 'genus', 'species', 'classified', 'unique_name',
       'lv1_group', 'lv2_group', 'score_faa', 'score_fna', 'score_rrna',
       'score_trna', 'total_length', 'contigs', 'gc', 'n50', 'l50', 'proteins',
       'protein_length', 'coding_density', 'completeness', 'contamination',
       'strain_heterogeneity', 'markers', '5s_rrna', '16s_rrna', '23s_rrna',
       'trnas', 'draft_quality'],
      dtype='object')

Taxa in the r-protein tree.

In [13]:
rpls_fp = 'rpls.txt'

In [14]:
with open(rpls_fp, 'r') as f:
    rpls = set(f.read().splitlines())
len(rpls)

9814

Melainabacteria genomes (used below to check their representation in the selection).

In [15]:
melai_fp = 'dte/Melainabacteria.txt'

In [16]:
with open(melai_fp, 'r') as f:
    melai = set(f.read().splitlines())
len(melai)

28

## Helper functions

Prototype selection

In [17]:
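def prototype_selection_destructive_maxdist(dm, num_prototypes, seedset=None):
    """Prototype selection by destructive maximum distance (minified).

    Starting from all elements of the distance matrix `dm`, repeatedly remove
    the element whose summed distance to the remaining elements is smallest,
    until `num_prototypes` elements are left. Elements in `seedset` start with
    an inflated sum so that they are unlikely to be removed.
    """
    numRemain = len(dm.ids)
    if num_prototypes >= numRemain:
        return list(dm.ids)
    # summed distance from each element to all elements
    currDists = dm.data.sum(axis=1)
    maxVal = currDists.max()
    if seedset is not None:
        for e in seedset:
            currDists[dm.index(e)] = maxVal * 2
    # remove the most "central" element; np.inf marks removed elements
    # (np.infty was removed in NumPy 2.0)
    minElmIdx = currDists.argmin()
    currDists[minElmIdx], numRemain = np.inf, numRemain - 1
    while numRemain > num_prototypes:
        # update sums so that they only count distances to remaining elements
        currDists -= dm.data[minElmIdx]
        minElmIdx = currDists.argmin()
        currDists[minElmIdx], numRemain = np.inf, numRemain - 1
    return [dm.ids[idx]
            for idx, dist in enumerate(currDists)
            if dist != np.inf]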
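To illustrate what the function does, here is a minimal pure-Python sketch of the same destructive max-distance idea on a hypothetical 4-taxon distance matrix (the names `destructive_maxdist`, `ids` and `dist` are illustrative only). It recomputes the sums at every step, which is easier to read but slower than the incremental update above; ties are broken by taking the first index, as `argmin` does.

```python
def destructive_maxdist(ids, dist, k):
    """Return the k ids left after repeatedly discarding the element whose
    summed distance to the still-remaining elements is smallest."""
    remaining = list(range(len(ids)))
    while len(remaining) > k:
        # summed distance from each remaining element to all remaining ones
        sums = {i: sum(dist[i][j] for j in remaining) for i in remaining}
        # min() returns the first minimum, like numpy's argmin on ties
        drop = min(remaining, key=sums.get)
        remaining.remove(drop)
    return [ids[i] for i in remaining]


# Hypothetical 4-taxon distance matrix (symmetric, zero diagonal)
ids = ['A', 'B', 'C', 'D']
dist = [[0, 2, 7, 9],
        [2, 0, 8, 6],
        [7, 8, 0, 3],
        [9, 6, 3, 0]]
print(destructive_maxdist(ids, dist, 3))  # ['A', 'C', 'D']
print(destructive_maxdist(ids, dist, 2))  # ['A', 'D']
```

B is dropped first (row sum 16), then C (sum 10 among A, C, D), leaving the most distant pair A and D.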

RED calculation

In [18]:
def calc_brlen_metrics(tree):
    """Calculate branch length-related metrics.

    Parameters
    ----------
    tree : skbio.TreeNode

    Notes
    -----
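    The following metrics are calculated:

        - height: Sum of branch lengths from the root to the node.

        - depths: Sums of branch lengths from the node to each of its
          descendant tips (one value per tip).

        - taxa: Sorted names of the descendant tips.

        - red: Relative evolutionary divergence (RED), introduced by Parks
          et al., 2018, Nat Biotechnol.

              RED = p + (d / u) * (1 - p)

          where p = RED of the parent, d = length of the branch connecting
          the node to its parent, and u = mean distance from the parent to
          the tips descending from the node (d + mean of depths). The root
          has RED 0 and tips have RED 1.

    Metrics are attached to each node of the tree in place.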
    """
    # calculate depths
    for node in tree.postorder(include_self=True):
        if node.name is None:
            raise ValueError('Error: Found an unnamed node.')
        if node.length is None:
            node.length = 0.0
        if node.is_tip():
            node.depths = [0.0]
            node.taxa = [node.name]
        else:
            node.depths = [
                y + x.length for x in node.children for y in x.depths]
            node.taxa = sorted(set().union(*[x.taxa for x in node.children]))

    # calculate heights and REDs
    for node in tree.preorder(include_self=True):
        if node.is_root():
            node.height = 0.0
            node.red = 0.0
        else:
            node.height = node.parent.height + node.length
            if node.is_tip():
                node.red = 1.0
            else:
                node.red = node.parent.red + node.length \
                    / (node.length + sum(node.depths) / len(node.depths)) \
                    * (1 - node.parent.red)
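As a sanity check of the recurrence, here is a minimal pure-Python sketch that computes RED on a small hypothetical tree stored as nested dicts, without scikit-bio (`red_values` and `toy` are illustrative names). It follows the same two passes as `calc_brlen_metrics`: a postorder pass collecting tip depths, then a preorder pass assigning RED.

```python
def red_values(tree):
    """Return {node name: RED} for a tree given as nested dicts.

    Each node is a dict with 'name', optional 'length' (branch to parent)
    and optional 'children'. Root -> 0, tips -> 1,
    internal nodes -> p + d / u * (1 - p).
    """
    depths = {}  # node name -> distances from the node to each descendant tip

    def collect(node):
        kids = node.get('children', [])
        if not kids:
            depths[node['name']] = [0.0]
        else:
            depths[node['name']] = [d + kid['length']
                                    for kid in kids for d in collect(kid)]
        return depths[node['name']]

    red = {}

    def assign(node, parent_red, is_root):
        kids = node.get('children', [])
        if is_root:
            value = 0.0
        elif not kids:
            value = 1.0
        else:
            ds = depths[node['name']]
            # u: mean distance from the parent to the tips below this node
            u = node['length'] + sum(ds) / len(ds)
            value = parent_red + node['length'] / u * (1 - parent_red)
        red[node['name']] = value
        for kid in kids:
            assign(kid, value, False)

    collect(tree)
    assign(tree, 0.0, True)
    return red


# Hypothetical toy tree, in Newick: (((T1:1,T2:3)Y:2,T3:4)X:4,T4:8)R;
toy = {'name': 'R', 'children': [
    {'name': 'X', 'length': 4, 'children': [
        {'name': 'Y', 'length': 2, 'children': [
            {'name': 'T1', 'length': 1},
            {'name': 'T2', 'length': 3}]},
        {'name': 'T3', 'length': 4}]},
    {'name': 'T4', 'length': 8}]}

red = red_values(toy)
print(red['X'], red['Y'])  # 0.5 0.75
```

For X, the tips lie 3, 5 and 4 below it (mean 4), so u = 4 + 4 = 8 and RED = 4/8 = 0.5; for Y, u = 2 + 2 = 4 and RED = 0.5 + (2/4)(1 - 0.5) = 0.75.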

Latin species name checker

In [19]:
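def is_latin(name):
    """Check whether a species name is a plain Latin binomial ("Genus species")."""
    # missing metadata (e.g., NaN) and empty strings are not Latin names
    if not isinstance(name, str) or name == '':
        return False
    elif name.count(' ') != 1:
        return False
    name_ = name.replace(' ', '')
    if not name_.istitle():
        return False
    elif not name_.isalpha():
        return False
    return True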
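A stricter alternative, sketched here as an assumption rather than the method used in this notebook, is a regular expression requiring a capitalized genus and a lowercase epithet, both non-empty. Unlike `is_latin`, it rejects edge cases such as a trailing space (`'Escherichia '` passes the space-count and `istitle` checks above), but it is restricted to ASCII letters, whereas `str.isalpha` also accepts other Unicode letters. `BINOMIAL` and `is_binomial` are hypothetical names.

```python
import re

# Hypothetical stricter pattern: capitalized ASCII genus + lowercase ASCII epithet
BINOMIAL = re.compile(r'[A-Z][a-z]+ [a-z]+')


def is_binomial(name):
    """Return True if name looks like 'Genus species'; non-strings return False."""
    return isinstance(name, str) and BINOMIAL.fullmatch(name) is not None


for name in ['Caldilinea aerophila', 'Archaeoglobus sp. JdFR-24',
             'Candidatus Caldiarchaeum subterraneum', 'Escherichia ']:
    print(repr(name), is_binomial(name))
```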

## Analysis

Optional: Filter tree to keep taxa with Latin species names.

In [20]:
g2species = dfg['species'].to_dict()

In [21]:
%%script false
to_keep = [x for x in tree.subset() if is_latin(g2species[x])]
tree = tree.shear(to_keep)
print('Tree has %d taxa with Latin species names.' % tree.count(tips=True))

Optional: Filter tree to keep taxa represented in the r-protein tree.

In [22]:
to_keep = [x for x in tree.subset() if x in rpls]
tree = tree.shear(to_keep)
print('Tree has %d taxa represented in the r-protein tree.' % tree.count(tips=True))

Tree has 9814 taxa represented in the r-protein tree.


### Solution I: Prototype selection based on phylogenetic distance.

In [23]:
%%script false
print('Tree has %d taxa.' % tree.count(tips=True))

print('Calculating tip-to-tip distances...')
dm = tree.tip_tip_distances()
print('Sum of distances: %d.' % np.tril(dm.data).sum())

print('Performing prototype selection...')
prototypes = prototype_selection_destructive_maxdist(dm, ntaxa)
print('Downsampled to %d taxa.' % len(prototypes))
print('Sum of distances: %d.' % np.tril(dm.filter(prototypes).data).sum())

In [24]:
%%script false
tout = tree.shear(prototypes)
tout.write('proto.%d.nwk' % ntaxa)

with open('proto.%d.txt' % ntaxa, 'w') as f:
    for g in sorted(prototypes):
        print(g, file=f)

gs = prototypes

### Solution II: Choose clades by minimizing RED.

Calculate branch length-related metrics.

In [25]:
calc_brlen_metrics(tree)

In [26]:
data = {}
for node in tree.postorder(include_self=True):
    data[node.name] = {
        'parent': None if node.is_root() else node.parent.name,
        'taxa': node.taxa,
        'red': node.red,
    }
len(data)

19627

Filter internal nodes by branch support value.

In [27]:
valid = set(dfs.query('EN > %d and LPP > %f' % (en_th, lpp_th)).index.tolist())
len(valid)

10027

A helper function that removes the current node, or its nearest ancestor that is already in the pool, so that the chosen clades do not overlap.

In [28]:
def remove_ancestor(nid, chosen, data):
    cid = nid
    while True:
        if cid in chosen:
            chosen.remove(cid)
            return cid
        else:
            cid = data[cid]['parent']
            if cid is None:
                return None

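Iterate over nodes in ascending order of RED (i.e., from the root toward the tips), considering only tips and well-supported internal nodes. Each newly added clade replaces its nearest ancestor in the pool, so the pool grows only when a clade with no chosen ancestor is added. Stop once the pool contains the desired number of clades (one per taxon to keep).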

In [29]:
chosen_clades = set()

In [30]:
for nid in sorted(data, key=lambda x: data[x]['red']):
    if nid.startswith('G') or nid in valid:
        remove_ancestor(nid, chosen_clades, data)
        chosen_clades.add(nid)
        if len(chosen_clades) == ntaxa:
            break
len(chosen_clades)

92
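To make the loop concrete, here is a minimal sketch that reproduces the same logic on a hypothetical 9-node tree (`remove_nearest_chosen`, `choose_clades`, `toy_data` and `toy_valid` are illustrative names; internal nodes start with `N` and tips with `G`, as above, and the RED values are invented). Note how N4 replaces its ancestor N2, after which the tip G1, no longer covered by any chosen clade, enters the pool on its own.

```python
def remove_nearest_chosen(nid, chosen, data):
    """Remove nid or its nearest ancestor from chosen (same as remove_ancestor)."""
    cid = nid
    while cid is not None:
        if cid in chosen:
            chosen.remove(cid)
            return cid
        cid = data[cid]['parent']
    return None


def choose_clades(data, valid, ntaxa):
    chosen = set()
    # visit nodes from the root-most (lowest RED) toward the tips
    for nid in sorted(data, key=lambda x: data[x]['red']):
        if nid.startswith('G') or nid in valid:
            remove_nearest_chosen(nid, chosen, data)
            chosen.add(nid)
            if len(chosen) == ntaxa:
                break
    return chosen


# Hypothetical toy tree: N1 is the root; tips have RED 1
toy_data = {
    'N1': {'parent': None, 'red': 0.0},
    'N2': {'parent': 'N1', 'red': 0.3},
    'N3': {'parent': 'N1', 'red': 0.4},
    'N4': {'parent': 'N2', 'red': 0.6},
    'G1': {'parent': 'N2', 'red': 1.0},
    'G2': {'parent': 'N4', 'red': 1.0},
    'G3': {'parent': 'N4', 'red': 1.0},
    'G4': {'parent': 'N3', 'red': 1.0},
    'G5': {'parent': 'N3', 'red': 1.0},
}
toy_valid = {'N2', 'N3', 'N4'}  # the root has no support value
print(sorted(choose_clades(toy_data, toy_valid, 2)))  # ['N2', 'N3']
print(sorted(choose_clades(toy_data, toy_valid, 3)))  # ['G1', 'N3', 'N4']
```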

Within each clade, select one taxon by sequentially applying the following criteria:
1. Contains the most marker genes.
2. Contamination level is the lowest.
3. DNA quality score is the highest.
4. Randomly choose one.
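The three functions below share one pattern: keep only the candidates tied at the best value. A minimal sketch of the whole cascade with a single generic tie filter, assuming per-genome values are plain dicts without missing entries (`keep_ties`, `pick_representative` and the toy metadata are hypothetical, not part of the pipeline). Candidates are sorted first so that a seeded random generator gives reproducible picks.

```python
import random


def keep_ties(candidates, values, best):
    """Keep the candidates whose value equals the best (max or min) value."""
    target = best(values[c] for c in candidates)
    return [c for c in candidates if values[c] == target]


def pick_representative(candidates, markers, contam, dnaqty, rng=random):
    """Most markers, then lowest contamination, then highest DNA quality;
    stop as soon as one candidate is left, otherwise pick at random."""
    gs = sorted(candidates)  # fixed order so a seeded rng is reproducible
    for values, best in ((markers, max), (contam, min), (dnaqty, max)):
        if len(gs) == 1:
            break
        gs = keep_ties(gs, values, best)
    return gs[0] if len(gs) == 1 else rng.choice(gs)


# Hypothetical per-genome metadata for one clade
markers = {'G1': 120, 'G2': 120, 'G3': 118}
contam = {'G1': 0.5, 'G2': 0.5, 'G3': 0.0}
dnaqty = {'G1': 0.90, 'G2': 0.95, 'G3': 1.00}
print(pick_representative(['G3', 'G1', 'G2'], markers, contam, dnaqty))  # G2
```

G3 has the lowest contamination and best DNA score but is eliminated by the first criterion (fewer markers), which illustrates that the criteria are applied in strict order rather than combined.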

In [31]:
chosen_taxa = {}

Criterion 1: Most marker genes.

In [32]:
g2markers = dfg['markers'].to_dict()

In [33]:
def most_markers(gs, g2markers):
    max_gs = []
    max_markers = 0
    for g in sorted(gs, key=lambda x: g2markers[x], reverse=True):
        if max_markers == 0:
            max_markers = g2markers[g]
            max_gs.append(g)
        elif max_markers == g2markers[g]:
            max_gs.append(g)
        else:
            break
    return max_gs, max_markers

Criterion 2: Lowest contamination.

In [34]:
g2contam = dfg['contamination'].to_dict()

In [35]:
def least_contaminated(gs, g2contam):
    min_gs = []
    min_contam = None
    for g in sorted(gs, key=lambda x: g2contam[x]):
        if min_contam is None:
            min_contam = g2contam[g]
            min_gs.append(g)
        elif min_contam == g2contam[g]:
            min_gs.append(g)
        else:
            break
    return min_gs, min_contam

Criterion 3: Highest DNA quality score.

In [36]:
g2dnaqty = dfg['score_fna'].to_dict()

In [37]:
def best_dna(gs, g2dnaqty):
    max_gs = []
    max_dnaqty = 0
    for g in sorted(gs, key=lambda x: g2dnaqty[x], reverse=True):
        if max_dnaqty == 0:
            max_dnaqty = g2dnaqty[g]
            max_gs.append(g)
        elif max_dnaqty == g2dnaqty[g]:
            max_gs.append(g)
        else:
            break
    return max_gs, max_dnaqty

Perform taxon selection.

In [38]:
chosen_taxa = {}

In [39]:
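# iterate in a fixed order: set iteration order of strings can change between
# Python sessions (hash randomization), which would make the seeded random choice irreproducible
for nid in sorted(chosen_clades, key=lambda x: (x[0], int(x[1:]))):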
    gs = data[nid]['taxa']

    # optional: r-protein
#     gs = rpls.intersection(gs)
#     if len(gs) == 0:
#         raise ValueError('%s: No taxon is represented in the r-protein tree.' % nid)

    # optional: Latin name
#     gs_ = [x for x in gs if is_latin(g2species[x])]
#     if len(gs_) == 0:
#         print('%s: No taxon has a Latin species name.' % nid)
#     else:
#         gs = gs_

    # maximize marker count
    gs, max_markers = most_markers(gs, g2markers)

    if len(gs) > 1:
        # minimize contamination
        gs, min_contam = least_contaminated(gs, g2contam)

    if len(gs) > 1:
        # maximize DNA quality
        gs, max_dnaqty = best_dna(gs, g2dnaqty)

    if len(gs) > 1:
        # random choice
        print('Clade %s: Equally good: %s.' % (nid, ', '.join(gs)))
        g = choice(gs)
    else:
        g = max(gs)

    chosen_taxa[nid] = g

Clade N250: Equally good: G000007005, G900079115.
Clade N1637: Equally good: G000283655, G001305595.


In [40]:
for nid, g in sorted(chosen_taxa.items(), key=lambda x: g2species[x[1]]):
    print('%s - %s: %s' % (nid, g, g2species[g]))

N741 - G001768645: Actinobacteria bacterium RBG_13_55_18
N1122 - G001399675: Alicyclobacillus ferrooxydans
N118 - G002011215: Archaeoglobus sp. JdFR-24
N966 - G000092245: Arcobacter nitrofigilis
N1637 - G000283655: Azospirillum lipoferum
N1334 - G000311725: Bacillus massiliosenegalensis
N1898 - G001769855: Bdellovibrionales bacterium GWA2_49_15
N1304 - G000281175: Caldilinea aerophila
N3441 - G001886815: Caldithrix abyssi
N94 - G002010385: Candidatus Acetothermia bacterium JdFR-52
N242 - G001593935: Candidatus Bathyarchaeota archaeon B26-2
N198 - G001774395: Candidatus Berkelbacteria bacterium RIFCSPLOWO2_01_FULL_50_28
N109 - G000270325: Candidatus Caldiarchaeum subterraneum
N190 - G001777435: Candidatus Daviesbacteria bacterium RIFCSPLOWO2_02_FULL_38_18
N1361 - G001777945: Candidatus Edwardsbacteria bacterium RIFOXYD12_FULL_50_11
N275 - G001779195: Candidatus Gottesmanbacteria bacterium RIFCSPHIGHO2_02_FULL_39_11
N15 - G001940645: Candidatus Heimdallarchaeota archaeon LC_3
N1607 - G00

In [41]:
%%script false
for g in data['N2193']['taxa']:
    print('%s: %s' % (g, g2species[g]))

Output results.

In [42]:
with open('red.%d.txt' % ntaxa, 'w') as f:
    for nid in sorted(chosen_taxa, key=lambda x: int(x[1:])):
        f.write('%s\t%s\n' % (chosen_taxa[nid], nid))

In [43]:
gs = sorted(chosen_taxa.values())

In [44]:
tout = tree.shear(gs)
tout.write('red.%d.nwk' % ntaxa)

'red.92.nwk'

In [45]:
%%script false
print('Sum of distances: %d.' % np.tril(dm.filter(gs).data).sum())

### Explore results.

In [46]:
def check_keyword(gs, df, col, word):
    res = []
    d_ = df[col].dropna().to_dict()
    for g in gs:
        if g in d_ and word in d_[g]:
            res.append('%s: %s' % (g, word))
    return res

In [47]:
check_keyword(gs, dfg, 'phylum', 'Cyanobacteria')

[]

In [48]:
melai.intersection(gs)

{'G001858525', 'G001899315'}

In [49]:
len(check_keyword(gs, dfg, 'lv1_group', 'Archaea'))

12

In [50]:
len(check_keyword(gs, dfg, 'lv1_group', 'CPR'))

15

In [51]:
for lv2 in dfg['lv2_group'].unique():
    print('%s: %d' % (lv2, len(check_keyword(gs, dfg, 'lv2_group', lv2))))

Firmicutes: 16
Euryarchaeota: 5
Actinobacteria: 4
Proteobacteria: 9
FCB: 4
Crenarchaeota: 1
Chlamydiae: 0
Bacteria: 17
Cyanobacteria: 0
Spirochaetes: 3
Terrabacteria: 3
Chloroflexi: 3
Bacteroidetes: 1
TACK: 3
PVC: 5
Archaea: 1
CPR: 5
Parcubacteria: 6
DPANN: 1
Microgenomates: 4
Asgard: 1
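The loop above calls `check_keyword` once per group, rescanning all selected genomes each time. A minimal sketch of a single-pass tally with `collections.Counter`, offered as an alternative rather than the method used here (`count_groups` and the toy mapping are hypothetical). It matches group names exactly, whereas `check_keyword` matches substrings, so results could differ if one group name were contained in another.

```python
from collections import Counter


def count_groups(genomes, g2group):
    """Tally selected genomes per group; genomes absent from g2group are skipped."""
    return Counter(g2group[g] for g in genomes if g in g2group)


# Hypothetical mapping, e.g. as produced by dfg['lv2_group'].dropna().to_dict()
g2group = {'G1': 'CPR', 'G2': 'CPR', 'G3': 'Asgard'}
counts = count_groups(['G1', 'G2', 'G3', 'G4'], g2group)
print(counts['CPR'], counts['Asgard'], counts['DPANN'])  # 2 1 0
```

A `Counter` returns 0 for groups that never occur, so groups with no selected genome can still be queried and reported.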
