# Source code for ocd.shots_db

# Observing condition decision tool: monitor conditions and plan HETDEX
# observations
# Copyright (C) 2017, 2018  "The HETDEX collaboration"
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
'''This module contains:

* SQLite3 database to store and manage the list of shots. The database
  interface is done using `peewee <http://docs.peewee-orm.com/en/latest/>`_
* functions to load and dump the database content
* an interface with a MySQL database to obtain the shot number

'''
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

import contextlib
import datetime
import os

from astropy.io import ascii
from astropy.table import Table
import numpy as np
import peewee
from pyhetdex.tools import db_helpers
import pymysql
import six

from ocd import errors
from ocd import utils


# Database name that SQLite uses for a database living only in memory
MEMORY_DB = ':memory:'


# interface

# create the database but do not initialize it
database = peewee.SqliteDatabase(None)
'Database where the information will be stored'


# ``keep_open`` receives the in-memory database name; presumably the connector
# skips closing such a database (see the description below) -- TODO confirm
# against ``pyhetdex.tools.db_helpers.SQLiteConnector``
connect = db_helpers.SQLiteConnector(database, keep_open=MEMORY_DB)
'''Context manager to open and close the database connection, unless it's in
memory'''


# Columns stored in the ``Shots`` table; ``shotid`` is deliberately the first
# entry because other code in this module relies on ``SHOT_NAMES[:1]``
SHOT_NAMES = ['shotid', 'ra', 'dec', 'track', 'priority', 'n_obs', 'forced_az']
'''Name of the columns in the input and output shot file that define a shot and
its priority, number of observations, ...'''
# Columns stored (together with ``shotid``) in the ``ShotMetadata`` table
METADATA_NAMES = ['QIDX', 'QPROG', 'QOBJECT', 'QIFU', 'QRA', 'QDEC',
                  'QEQUINOX', 'QPMRA', 'QPMDEC', 'QEPOCH']
'''Name of the columns in the input shot file that contains metadata to pass to
TCS'''
# The shot columns come first, followed by the metadata columns
COLUMN_NAMES = SHOT_NAMES + METADATA_NAMES
'''Name of the mandatory columns of the input shot file'''


class BaseModel(peewee.Model):
    '''Common parent of the tables that describe shots.

    It provides the unique ``shotid`` column shared by all the tables and
    binds the derived models to the module-level ``database``.
    '''
    shotid = peewee.CharField(help_text='ID of the shot', unique=True)

    class Meta:
        database = database
class Shots(BaseModel):
    '''Table holding the definition of each shot and its scheduling status.

    The ``CHECK`` constraints make the database reject a ``track`` outside
    ``(0, 1, 2)``, a ``priority`` smaller than 1 and a negative ``n_obs``.
    '''
    ra = peewee.FloatField(help_text='RA in decimal hours')
    dec = peewee.FloatField(help_text='DEC in decimal degrees')

    # kept as a class attribute so the constraint list can be reused
    _track_constraints = [peewee.Check('track IN (0, 1, 2)')]
    track = peewee.IntegerField(constraints=_track_constraints)

    priority = peewee.IntegerField(
        constraints=[peewee.Check('priority >= 1')])
    n_obs = peewee.IntegerField(
        constraints=[peewee.Check('n_obs >= 0')])
    forced_az = peewee.FloatField()
class ShotMetadata(BaseModel):
    '''Metadata attached to a shot.

    Every row of :class:`Shots` is expected to have a companion row, with the
    same ``shotid``, in this table.
    '''
    # integer column
    QIDX = peewee.IntegerField()

    # text columns
    QPROG = peewee.CharField()
    QOBJECT = peewee.CharField()
    QIFU = peewee.CharField()
    QRA = peewee.CharField()
    QDEC = peewee.CharField()

    # floating point columns
    QEQUINOX = peewee.FloatField()
    QPMRA = peewee.FloatField()
    QPMDEC = peewee.FloatField()
    QEPOCH = peewee.FloatField()
def init(conf):
    '''Bind the database to its file and make sure that the tables exist.

    Parameters
    ----------
    conf : :class:`pyhetdex.tools.configuration.ConfigParser`
        configuration. The following options of the ``[database]`` section
        are used:

        * ``database_name``: name of the SQLite database; the special value
          ``:tmpfile:`` selects the file ``shot_list.db`` in the directory
          returned by :func:`ocd.utils.tmpdir`;
        * ``drop_existing_tables``: if true, drop the tables before creating
          them; defaults to false.
    '''
    section = conf['database']

    name = section['database_name']
    if name == ':tmpfile:':
        name = os.path.join(utils.tmpdir(), 'shot_list.db')
    database.init(name)

    tables = [Shots, ShotMetadata]
    with connect():
        drop_first = section.getboolean('drop_existing_tables',
                                        fallback=False)
        if drop_first:
            database.drop_tables(tables, safe=True)
        # ``safe=True``: do not fail if the tables are already there
        database.create_tables(tables, safe=True)
def store_shot_file(conf):
    '''Read the shot file and add all its elements to the database.

    Parameters
    ----------
    conf : :class:`pyhetdex.tools.configuration.ConfigParser`
        configuration. It uses the following options of the ``[shots]``
        section:

        * ``shot_file``: name of the shot file;
        * ``clear_shot_table``: if found and true, clear the shot and the
          metadata tables before inserting the new shot file;
        * ``update_existing_ids``: if a shotid already exists in the database
          update it with the values from the shot file
    '''
    shot_sec = conf['shots']

    if shot_sec.getboolean('clear_shot_table'):
        # Both tables are keyed by the same unique ``shotid`` and are filled
        # together, so they must be cleared together: leaving the old
        # metadata rows in place makes the bulk insert below fail on
        # duplicated shot ids.
        with connect():
            for model in (Shots, ShotMetadata):
                model.drop_table()
                model.create_table()

    # get the shot file and convert it to records
    shot_file = shot_sec['shot_file']
    table = load_shot_file(shot_file)
    shot_records = table_to_records(table, columns=SHOT_NAMES)
    # the metadata records need the shot id to be tied to their shot
    metadata_records = table_to_records(table,
                                        columns=SHOT_NAMES[:1]+METADATA_NAMES)

    # insert the records to the database
    # use bulk insert only if the database is empty
    with connect():
        if Shots.select().count() == 0:
            bulk_insert(shot_records, metadata_records)
        else:
            update_existing_ids = shot_sec.getboolean('update_existing_ids')
            insert_or_update(shot_records, metadata_records,
                             update_existing_ids)
def create_shot_file(conf):
    '''Dump the content of the :class:`Shots` table into a shot file.

    Parameters
    ----------
    conf : :class:`pyhetdex.tools.configuration.ConfigParser`
        configuration. The following options of the ``[shots]`` section are
        used:

        * ``out_shot_dir``: directory where to store the shot files; see
          :func:`load_shot_file` for the description of the columns;
        * ``out_shot_file_template``: template for the output shot file name;
        * ``keep_shot_files``: integer handed to
          :func:`ocd.utils.get_out_file`; defaults to -1;
        * ``all_shots``: if found and true, save all the shots, otherwise
          save only the shots with ``n_obs`` larger than zero.

    Returns
    -------
    out_shot_file : string
        name of the file that has been written, as returned by
        :func:`ocd.utils.get_out_file`
    '''
    shots_conf = conf['shots']
    out_dir = shots_conf['out_shot_dir']
    template = shots_conf['out_shot_file_template']
    n_keep = shots_conf.getint('keep_shot_files', fallback=-1)
    out_shot_file = utils.get_out_file(out_dir, template, 'ocd_shot_{}.list',
                                       n_keep)

    with connect():
        query = Shots.select()
        if not shots_conf.getboolean('all_shots', fallback=False):
            # drop the shots that do not need to be observed any more
            query = query.where(Shots.n_obs > 0)
        table = query_to_table(query)

    table.write(out_shot_file, format='ascii.commented_header',
                overwrite=True)
    return out_shot_file
def update_shot(shots_dict, success=True):
    '''Decrease by one the number of observations left for a shot and freeze
    the azimuth and the track that have been used, if they were not set.

    .. todo::

        This function is not in its final form. The current functionality is
        very limited and will be expanded. See :issue:`2056`, :issue:`2057`
        and :issue:`2080`.

    Parameters
    ----------
    shots_dict : dict
        parameters used to run the shot. The relevant entries are:

        * shotid (string): id of the shot
        * azimuth (float): commanded azimuth; it replaces
          :attr:`Shots.forced_az` when the latter is negative
        * track (int): commanded track; it replaces :attr:`Shots.track` when
          the latter is 2
    success : bool, optional
        whether the shot is successful or not; the current implementation
        does not use this value

    Raises
    ------
    ShotDoesNotExist
        if the shot id is not found
    ShotIntegrityError
        if saving the updated shot violates a constraint of the table
    '''
    with connect():
        try:
            shot = Shots.get(Shots.shotid == shots_dict['shotid'])
            shot.n_obs = shot.n_obs - 1
            # the azimuth and track are read from the input only when the
            # database does not impose them
            if shot.forced_az < 0:
                shot.forced_az = shots_dict['azimuth']
            if shot.track == 2:
                shot.track = shots_dict['track']
            shot.save()
        except Shots.DoesNotExist as exc:
            six.raise_from(errors.ShotDoesNotExist(exc), exc)
        except peewee.IntegrityError as exc:
            six.raise_from(errors.ShotIntegrityError(exc), exc)
def get_metadata(shotid):
    '''Retrieve the metadata stored for the given shot id.

    Parameters
    ----------
    shotid : string
        shot id to search

    Returns
    -------
    metadata : dict
        values of the :data:`METADATA_NAMES` columns, keyed by column name

    Raises
    ------
    ShotDoesNotExist
        if the ``shotid`` is not found in the metadata table
    '''
    fields = [getattr(ShotMetadata, column) for column in METADATA_NAMES]
    with connect():
        query = ShotMetadata.select(*fields)
        query = query.where(ShotMetadata.shotid == shotid)
        query = query.dicts()
        try:
            # the query is executed only here
            metadata = query.get()
        except ShotMetadata.DoesNotExist as exc:
            six.raise_from(errors.ShotDoesNotExist(exc), exc)

    return metadata
# implementation
def bulk_insert(shot_records, metadata_records):
    '''Do a bulk insert of the shot and the metadata records.

    The two inputs are expected to have the same length and ordering, as
    they are sliced with the same indices. If both are empty the function
    returns without touching the database.

    Parameters
    ----------
    shot_records : iterable of dicts
        records corresponding to the :class:`Shots` table
    metadata_records : iterable of dicts
        records corresponding to the :class:`ShotMetadata` table
    '''
    # materialise the inputs: generators can be neither measured nor sliced
    shot_records = list(shot_records)
    metadata_records = list(metadata_records)
    if not shot_records and not metadata_records:
        # e.g. a shot file without rows: looking at the first record below
        # would fail with an IndexError
        return

    max_len = max(len(shot_records[0]), len(metadata_records[0]))
    # sqlite3 has a maximum number of variables that can replace when
    # executing a query. Make sure not to go above that number, but never
    # let the chunk size drop below one record, otherwise ``range`` either
    # fails or silently yields nothing
    chunk_size = max(1,
                     (db_helpers.SQLITE_MAX_VARIABLE_NUMBER // max_len) - 1)
    for start in range(0, len(shot_records), chunk_size):
        stop = start + chunk_size
        Shots.insert_many(shot_records[start:stop]).execute()
        ShotMetadata.insert_many(metadata_records[start:stop]).execute()
def insert_or_update(shot_records, metadata_records, update_exising):
    '''Add the records one by one to the database.

    A record whose shot id is not in the database is inserted. If the shot
    id already exists, the stored entries are overwritten with the record
    when ``update_exising`` is true and left untouched otherwise.

    Parameters
    ----------
    shot_records : list of dicts
        records corresponding to the :class:`Shots` table
    metadata_records : iterable of dicts
        records corresponding to the :class:`ShotMetadata` table
    update_exising : bool
        whether existing shot ids must be updated by a corresponding record
    '''
    for shot_rec, meta_rec in zip(shot_records, metadata_records):
        shot_row, created = Shots.get_or_create(shotid=shot_rec['shotid'],
                                                defaults=shot_rec)
        meta_row, _ = ShotMetadata.get_or_create(shotid=meta_rec['shotid'],
                                                 defaults=meta_rec)
        if created or not update_exising:
            # either just inserted or updates are not requested
            continue

        # overwrite everything but the id: first the shot, then its metadata
        for row, record in ((shot_row, shot_rec), (meta_row, meta_rec)):
            for name, value in record.items():
                if name != 'shotid':
                    setattr(row, name, value)
            row.save()
def load_shot_file(shot_file):
    '''Load the shot file.

    The file must contain at least the columns given by
    :data:`COLUMN_NAMES`:

    * ``shotid`` (string): the ID of the shot; must be unique;
    * ``ra`` (float): ra of the shot in decimal hours;
    * ``dec`` (float): dec of the shot in decimal degrees;
    * ``track`` (integer): must be one of 0, 1, 2; 0: force EAST track; 1:
      force WEST track; 2: let the software decide the optimal track;
    * ``priority`` (int): priority for the shot; 1 is the highest priority
      and decreases as the number increases; must be equal or larger than 1;
    * ``n_obs`` (int): number of times the shot must be observed;
    * ``forced_az`` (float): if -1, the scheduling algorithm decides the best
      azimuth to use, if positive forces the azimuth to the given value.
    * ``Q*`` (from :data:`METADATA_NAMES`): extra columns that connect each
      HETDEX shot to the global HET queue

    The shot file can be in any format supported by `astropy ascii
    <http://docs.astropy.org/en/stable/io/ascii/index.html>`_. The file must
    have a header with all the :data:`COLUMN_NAMES` columns.

    Parameters
    ----------
    shot_file : string
        name of the file to load

    Returns
    -------
    table : :class:`astropy.table.Table` instance
        relevant part of the shot file

    Raises
    ------
    ShotColumnError
        if a mandatory column is missing
    '''
    # Force the type of the columns that must not be guessed from their
    # content. The builtin ``str`` and ``float`` are used because the
    # ``np.str`` and ``np.float`` aliases of those very same builtins have
    # been removed from recent versions of numpy.
    converters = {'shotid': [ascii.convert_numpy(str), ],
                  'QIFU': [ascii.convert_numpy(str), ],
                  'QEQUINOX': [ascii.convert_numpy(float), ],
                  'QEPOCH': [ascii.convert_numpy(float), ],
                  }
    in_table = ascii.read(shot_file, converters=converters)

    # create a new table with only the relevant content
    try:
        table = [in_table[mc] for mc in COLUMN_NAMES]
    except KeyError as e:
        msg = "The mandatory column {col} is missing from the file '{fn}'"
        msg = msg.format(col=str(e), fn=shot_file)
        six.raise_from(errors.ShotColumnError(msg), e)

    return Table(table)
def table_to_records(table, columns=None):
    '''Generator converting each row of the input
    :class:`astropy.table.Table` into a dictionary.

    Parameters
    ----------
    table : :class:`astropy.table.Table` instance
        relevant part of the shot file
    columns : list of strings, optional
        if given and not empty, only these columns end up in the records

    Yields
    ------
    record : dictionary
        row values keyed by column name
    '''
    selected = table[columns] if columns else table
    for row in selected:
        yield dict(zip(row.colnames, row.as_void()))
def query_to_table(query):
    '''Build an astropy Table out of a query on the shots.

    Only the :data:`SHOT_NAMES` columns are copied, in that order.

    Parameters
    ----------
    query : :class:`peewee.SelectQuery`
        table query to convert

    Returns
    -------
    table : :class:`astropy.table.Table`
        table representation of the query
    '''
    rows = [[record[name] for name in SHOT_NAMES]
            for record in query.dicts()]
    return Table(rows=rows, names=SHOT_NAMES)
# MySQL interface

OBSNUM_TABLE_NAME = 'vl_obsnum'
'''name of the table in the mysql database containing the observation
number'''
@contextlib.contextmanager
def mysql_connection(conf):
    '''Context manager that opens a MySQL connection and closes it when the
    context is left, whether or not an error happened.

    Parameters
    ----------
    conf : :class:`pyhetdex.tools.configuration.ConfigParser`
        configuration. The following options of the ``[database]`` section
        are used:

        * ``mysql_host``, ``mysql_port``: IP/host name and port of the MySQL
          server
        * ``mysql_database``: name of the database
        * ``mysql_user``, ``mysql_password``: name and password of the user
          that does the queries

    Yields
    ------
    conn : :class:`pymysql.connections.Connection`
        MySQL connection
    '''
    section = conf['database']
    parameters = {'host': section['mysql_host'],
                  'port': section.getint('mysql_port'),
                  'database': section['mysql_database'],
                  'user': section['mysql_user'],
                  'password': section['mysql_password']}
    connection = pymysql.connect(**parameters)
    # ``closing`` calls ``connection.close()`` on the way out
    with contextlib.closing(connection):
        yield connection
def get_obsnumber(conf):
    '''Get the next observation number for the current UTC date and the
    "virus" instrument.

    The highest ``obsnum`` stored in the MySQL table
    :data:`OBSNUM_TABLE_NAME` is increased by one, optionally inserted back
    in the table and returned. If no ``obsnum`` is found, 1 is used.

    Parameters
    ----------
    conf : :class:`pyhetdex.tools.configuration.ConfigParser`
        configuration. The following options of the ``[database]`` section
        are used:

        * same as :func:`mysql_connection`
        * ``max_obsnum``: the new observation number must be strictly smaller
          than this value
        * ``mysql_update_obsnum``: if ``true`` (default), the new ``obsnum``
          is inserted in the database, otherwise the insertion is skipped.

    Returns
    -------
    next_obsnum : int
        next observation number

    Raises
    ------
    ObsnumIntegrityError
        if the new observation number is larger or equal than ``max_obsnum``
    '''
    db_conf = conf['database']
    instrument = 'virus'
    # current ut day, in the format used to query the ``obsdate`` column
    ut_day = datetime.datetime.utcnow().strftime('%Y%m%d')

    with mysql_connection(conf) as conn:
        cursor = conn.cursor()
        try:
            # the table name cannot be passed as a query argument, so it is
            # formatted into the statement
            select = "SELECT MAX(obsnum) FROM {} WHERE obsdate=%s and inst=%s"
            cursor.execute(select.format(OBSNUM_TABLE_NAME),
                           args=(ut_day, instrument))
            last_obsnum = cursor.fetchone()[0]
            # start from 1 when there is no entry for the day
            next_obsnum = 1 if last_obsnum is None else last_obsnum + 1

            max_obsnum = db_conf.getint('max_obsnum')
            if next_obsnum >= max_obsnum:
                msg = ('The observation number {n} is larger or equal than the'
                       ' maximum allowed value {max_}')
                msg = msg.format(n=next_obsnum, max_=max_obsnum)
                raise errors.ObsnumIntegrityError(msg)

            if db_conf.getboolean('mysql_update_obsnum', fallback=True):
                insert = ("INSERT INTO {} (obsdate, inst, obsnum) VALUES"
                          " (%s, %s, %s)")
                cursor.execute(insert.format(OBSNUM_TABLE_NAME),
                               args=(ut_day, instrument, next_obsnum))
        except Exception:
            # undo whatever has been done before propagating the error
            conn.rollback()
            raise
        else:
            conn.commit()
        finally:
            cursor.close()

    return next_obsnum
def create_obsnum_table(conf):
    '''Create the MySQL table called :attr:`OBSNUM_TABLE_NAME`

    Parameters
    ----------
    conf : :class:`pyhetdex.tools.configuration.ConfigParser`
        configuration. The following options of the ``[database]`` section
        are used:

        * same as :func:`mysql_connection`
    '''
    # one column definition per line; the pieces are joined by the parser
    statement = (" CREATE TABLE {} (id smallint(5) UNSIGNED NOT NULL"
                 " AUTO_INCREMENT,"
                 " ts TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,"
                 " obsdate DATE NOT NULL,"
                 " inst VARCHAR(5) NOT NULL,"
                 " obsnum mediumint NOT NULL,"
                 " PRIMARY KEY (id),"
                 " INDEX (obsdate) ); ")
    statement = statement.format(OBSNUM_TABLE_NAME)

    with mysql_connection(conf) as connection:
        connection.cursor().execute(statement)
        connection.commit()
def fill_obsnum_table(conf, date=None):
    '''Add one line, for the "virus" instrument and with ``obsnum`` 10, to
    the table :attr:`OBSNUM_TABLE_NAME`

    Parameters
    ----------
    conf : :class:`pyhetdex.tools.configuration.ConfigParser`
        configuration. The following options of the ``[database]`` section
        are used:

        * same as :func:`mysql_connection`
    date : :class:`datetime.datetime`, optional
        if given, use it for the ``ts`` and ``obsdate`` fields, otherwise use
        the current utc time
    '''
    timestamp = date if date else datetime.datetime.utcnow()
    obsdate = timestamp.strftime('%Y%m%d')

    statement = (" INSERT INTO {} (ts, obsdate, inst, obsnum)"
                 " VALUES (%s, %s, 'virus', 10); ")
    statement = statement.format(OBSNUM_TABLE_NAME)

    with mysql_connection(conf) as connection:
        connection.cursor().execute(statement, args=(timestamp, obsdate))
        connection.commit()