# markov3_io.py

from __future__ import print_function

"""Module for reading and writing unigram and bigram
data using tab-separated-value (TSV) format."""

# We put these functions together in a single module because
# changing one immediately requires changing the other to
# match.  If they were dispersed into individual programs, the
# same modification would have to be made in every program
# instead of just in this one module.

def write_grams(path, m1, m2):
    """Write unigram map *m1* and bigram map *m2* to the gram file for *path*.

    The target file is the one returned by ``_open_gram_file(path, "w")``,
    so any existing content is overwritten.  Layout:

    * one line per *m1* entry: the key, then each suffix string, all
      separated by tabs;
    * a single empty line marking the end of the unigram section;
    * one line per *m2* entry: the key's first and second elements, then
      each suffix string, all separated by tabs.

    Suffix values must be iterables of strings, since they are joined
    with tabs.
    """
    with _open_gram_file(path, "w") as out:
        for key, followers in m1.items():
            print(key, "\t".join(followers), sep="\t", file=out)
        # Section separator read back by read_grams.
        print(file=out)
        for pair, followers in m2.items():
            print(pair[0], pair[1], "\t".join(followers), sep="\t", file=out)

def read_grams(path):
    """Read the unigram and bigram maps from the gram file for *path*.

    The data comes from the file returned by ``_open_gram_file(path)``, in
    the layout produced by ``write_grams``:

    * one line per unigram entry: ``prefix<TAB>suffix[<TAB>suffix...]``;
    * a blank line;
    * one line per bigram entry:
      ``prefix0<TAB>prefix1<TAB>suffix[<TAB>suffix...]``.

    A blank line inside the bigram section ends reading, and any lines
    after it are ignored.  Each line is stripped of surrounding whitespace
    before being split on tabs.

    Returns:
        ``(m1, m2)``.  ``m1`` maps a prefix string to its list of suffix
        strings.  ``m2`` maps a ``(prefix0, prefix1)`` tuple to its list of
        suffix strings.

    Raises:
        IOError: if the file cannot be opened, or if a non-blank entry has
            no suffix field (fewer than 2 fields in the unigram section,
            fewer than 3 in the bigram section).
    """
    # We throw an IOError when we encounter data that we cannot
    # handle because it's the same type of exception thrown when
    # a file is missing.  Any callers can catch both errors with
    # a single except clause.  This is nice since there is no
    # need to distinguish the error types (no file, bad file)
    # since the end effect is the same: there's no data available.
    #
    # The with-statement guarantees the file is closed even when a
    # malformed line raises part-way through.
    with _open_gram_file(path) as f:
        # One enumerate iterator is shared by both sections, so the line
        # numbers in error messages keep counting across the separator.
        numbered = enumerate(f, 1)
        m1 = {}
        for values in _read_gram_section(numbered, 2, "m1"):
            m1[values[0]] = values[1:]
        m2 = {}
        for values in _read_gram_section(numbered, 3, "m2"):
            m2[(values[0], values[1])] = values[2:]
    return m1, m2


def _read_gram_section(numbered_lines, min_fields, label):
    # Yield the tab-separated fields of each line taken from
    # *numbered_lines* (an iterator of (line_number, line) pairs) until a
    # blank line or the end of input.  The blank line is consumed, so the
    # caller's next section starts on the line after it.  A non-blank line
    # with fewer than *min_fields* fields raises IOError naming *label*.
    for line_number, line in numbered_lines:
        text = line.strip()
        if not text:
            # Blank line ends the section.  This is checked before splitting
            # because ''.split('\t') == [''], which would otherwise be
            # indistinguishable from a malformed one-field data line.
            return
        values = text.split('\t')
        if len(values) < min_fields:
            raise IOError("unexpected %s data at line %d" %
                          (label, line_number))
        yield values

def _open_gram_file(path, mode="r"):
    # Return the file object to the data file associated with
    # the text at the given path
    import os.path
    root, ext = os.path.splitext(path)
    tsv_name = root + ".tsv"
    return open(tsv_name, mode)

#
# shift is placed in this file because the scripts that call
# read_grams or write_grams often need it as well.
#
def shift(prefix, word):
    """Return *prefix* minus its first element, with *word* appended.

    *prefix* must be a tuple, because the result is built by tuple
    concatenation.  Presumably it is a bigram (m2) key, but verify
    against callers.  The result has the same length as *prefix*, for
    example ``shift(("a", "b"), "c") == ("b", "c")``.
    """
    remainder = prefix[1:]
    return remainder + (word,)