Model: gemini-2.5-flash-preview-04-17
-- none --
20ef987f10245e0cacef184dbeef40309a6dfcb86b90574d0e0172440800917c
============================= test session starts ==============================
platform darwin -- Python 3.13.3, pytest-8.3.5, pluggy-1.5.0
rootdir: /Users/simon/Dropbox/Development/sqlite-utils
plugins: anyio-4.9.0, hypothesis-6.131.9
collected 1011 items
tests/test_analyze.py .... [ 0%]
tests/test_analyze_tables.py ................ [ 1%]
tests/test_attach.py . [ 2%]
tests/test_cli.py ...................................................... [ 7%]
........................................................................ [ 14%]
..........................................sss...... [ 19%]
tests/test_cli_bulk.py ... [ 19%]
tests/test_cli_convert.py ......................................... [ 23%]
tests/test_cli_insert.py ........................................... [ 28%]
tests/test_cli_memory.py ............................ [ 30%]
tests/test_column_affinity.py .......................................... [ 35%]
............ [ 36%]
tests/test_constructor.py ...... [ 36%]
tests/test_conversions.py ...... [ 37%]
tests/test_convert.py ................ [ 39%]
tests/test_create.py ................................................... [ 44%]
......................................................s................. [ 51%]
............................... [ 54%]
tests/test_create_view.py ...... [ 54%]
tests/test_default_value.py ............. [ 56%]
tests/test_delete.py ..... [ 56%]
tests/test_docs.py ..................................................... [ 61%]
........................................... [ 66%]
tests/test_duplicate.py .. [ 66%]
tests/test_enable_counts.py ....... [ 67%]
tests/test_extract.py .......... [ 68%]
tests/test_extracts.py FFFFFF [ 68%]
tests/test_fts.py .............................................. [ 73%]
tests/test_get.py ...... [ 73%]
tests/test_gis.py ............ [ 74%]
tests/test_hypothesis.py .... [ 75%]
tests/test_insert_files.py ......... [ 76%]
tests/test_introspect.py .......................................... [ 80%]
tests/test_lookup.py .....F.. [ 81%]
tests/test_m2m.py ........... [ 82%]
tests/test_plugins.py .. [ 82%]
tests/test_query.py .. [ 82%]
tests/test_recipes.py .................. [ 84%]
tests/test_recreate.py ...... [ 85%]
tests/test_register_function.py ...... [ 85%]
tests/test_rows.py ................ [ 87%]
tests/test_rows_from_file.py ....... [ 87%]
tests/test_sniff.py .... [ 88%]
tests/test_suggest_column_types.py ................ [ 89%]
tests/test_tracer.py .. [ 90%]
tests/test_transform.py ................................................ [ 94%]
........... [ 95%]
tests/test_update.py ................... [ 97%]
tests/test_upsert.py ......... [ 98%]
tests/test_utils.py ............ [ 99%]
tests/test_wal.py . [100%]
=================================== FAILURES ===================================
_____________________ test_extracts[True-kwargs0-Species] ______________________
fresh_db = <Database <sqlite3.Connection object at 0x1051dc4f0>>
kwargs = {'extracts': {'species_id': 'Species'}}, expected_table = 'Species'
use_table_factory = True
@pytest.mark.parametrize(
"kwargs,expected_table",
[
(dict(extracts={"species_id": "Species"}), "Species"),
(dict(extracts=["species_id"]), "species_id"),
(dict(extracts=("species_id",)), "species_id"),
],
)
@pytest.mark.parametrize("use_table_factory", [True, False])
def test_extracts(fresh_db, kwargs, expected_table, use_table_factory):
table_kwargs = {}
insert_kwargs = {}
if use_table_factory:
table_kwargs = kwargs
else:
insert_kwargs = kwargs
trees = fresh_db.table("Trees", **table_kwargs)
trees.insert_all(
[
{"id": 1, "species_id": "Oak"},
{"id": 2, "species_id": "Oak"},
{"id": 3, "species_id": "Palm"},
],
**insert_kwargs
)
# Should now have two tables: Trees and Species
assert {expected_table, "Trees"} == set(fresh_db.table_names())
assert (
"CREATE TABLE [{}] (\n [id] INTEGER PRIMARY KEY,\n [value] TEXT\n)".format(
expected_table
)
== fresh_db[expected_table].schema
)
assert (
"CREATE TABLE [Trees] (\n [id] INTEGER,\n [species_id] INTEGER REFERENCES [{}]([id])\n)".format(
expected_table
)
== fresh_db["Trees"].schema
)
# Should have a foreign key reference
assert len(fresh_db["Trees"].foreign_keys) == 1
fk = fresh_db["Trees"].foreign_keys[0]
assert fk.table == "Trees"
assert fk.column == "species_id"
# Should have unique index on Species
> assert [
Index(
seq=0,
name="idx_{}_value".format(expected_table),
unique=1,
origin="c",
partial=0,
columns=["value"],
)
] == fresh_db[expected_table].indexes
E AssertionError: assert [Index(seq=0,...ns=['value'])] == []
E
E Left contains one more item: Index(seq=0, name='idx_Species_value', unique=1, origin='c', partial=0, columns=['value'])
E Use -v to get more diff
tests/test_extracts.py:51: AssertionError
____________________ test_extracts[True-kwargs1-species_id] ____________________
fresh_db = <Database <sqlite3.Connection object at 0x1051dde40>>
kwargs = {'extracts': ['species_id']}, expected_table = 'species_id'
use_table_factory = True
@pytest.mark.parametrize(
"kwargs,expected_table",
[
(dict(extracts={"species_id": "Species"}), "Species"),
(dict(extracts=["species_id"]), "species_id"),
(dict(extracts=("species_id",)), "species_id"),
],
)
@pytest.mark.parametrize("use_table_factory", [True, False])
def test_extracts(fresh_db, kwargs, expected_table, use_table_factory):
table_kwargs = {}
insert_kwargs = {}
if use_table_factory:
table_kwargs = kwargs
else:
insert_kwargs = kwargs
trees = fresh_db.table("Trees", **table_kwargs)
trees.insert_all(
[
{"id": 1, "species_id": "Oak"},
{"id": 2, "species_id": "Oak"},
{"id": 3, "species_id": "Palm"},
],
**insert_kwargs
)
# Should now have two tables: Trees and Species
assert {expected_table, "Trees"} == set(fresh_db.table_names())
assert (
"CREATE TABLE [{}] (\n [id] INTEGER PRIMARY KEY,\n [value] TEXT\n)".format(
expected_table
)
== fresh_db[expected_table].schema
)
assert (
"CREATE TABLE [Trees] (\n [id] INTEGER,\n [species_id] INTEGER REFERENCES [{}]([id])\n)".format(
expected_table
)
== fresh_db["Trees"].schema
)
# Should have a foreign key reference
assert len(fresh_db["Trees"].foreign_keys) == 1
fk = fresh_db["Trees"].foreign_keys[0]
assert fk.table == "Trees"
assert fk.column == "species_id"
# Should have unique index on Species
> assert [
Index(
seq=0,
name="idx_{}_value".format(expected_table),
unique=1,
origin="c",
partial=0,
columns=["value"],
)
] == fresh_db[expected_table].indexes
E AssertionError: assert [Index(seq=0,...ns=['value'])] == []
E
E Left contains one more item: Index(seq=0, name='idx_species_id_value', unique=1, origin='c', partial=0, columns=['value'])
E Use -v to get more diff
tests/test_extracts.py:51: AssertionError
____________________ test_extracts[True-kwargs2-species_id] ____________________
fresh_db = <Database <sqlite3.Connection object at 0x1051dfc40>>
kwargs = {'extracts': ('species_id',)}, expected_table = 'species_id'
use_table_factory = True
@pytest.mark.parametrize(
"kwargs,expected_table",
[
(dict(extracts={"species_id": "Species"}), "Species"),
(dict(extracts=["species_id"]), "species_id"),
(dict(extracts=("species_id",)), "species_id"),
],
)
@pytest.mark.parametrize("use_table_factory", [True, False])
def test_extracts(fresh_db, kwargs, expected_table, use_table_factory):
table_kwargs = {}
insert_kwargs = {}
if use_table_factory:
table_kwargs = kwargs
else:
insert_kwargs = kwargs
trees = fresh_db.table("Trees", **table_kwargs)
trees.insert_all(
[
{"id": 1, "species_id": "Oak"},
{"id": 2, "species_id": "Oak"},
{"id": 3, "species_id": "Palm"},
],
**insert_kwargs
)
# Should now have two tables: Trees and Species
assert {expected_table, "Trees"} == set(fresh_db.table_names())
assert (
"CREATE TABLE [{}] (\n [id] INTEGER PRIMARY KEY,\n [value] TEXT\n)".format(
expected_table
)
== fresh_db[expected_table].schema
)
assert (
"CREATE TABLE [Trees] (\n [id] INTEGER,\n [species_id] INTEGER REFERENCES [{}]([id])\n)".format(
expected_table
)
== fresh_db["Trees"].schema
)
# Should have a foreign key reference
assert len(fresh_db["Trees"].foreign_keys) == 1
fk = fresh_db["Trees"].foreign_keys[0]
assert fk.table == "Trees"
assert fk.column == "species_id"
# Should have unique index on Species
> assert [
Index(
seq=0,
name="idx_{}_value".format(expected_table),
unique=1,
origin="c",
partial=0,
columns=["value"],
)
] == fresh_db[expected_table].indexes
E AssertionError: assert [Index(seq=0,...ns=['value'])] == []
E
E Left contains one more item: Index(seq=0, name='idx_species_id_value', unique=1, origin='c', partial=0, columns=['value'])
E Use -v to get more diff
tests/test_extracts.py:51: AssertionError
_____________________ test_extracts[False-kwargs0-Species] _____________________
fresh_db = <Database <sqlite3.Connection object at 0x1051de980>>
kwargs = {'extracts': {'species_id': 'Species'}}, expected_table = 'Species'
use_table_factory = False
@pytest.mark.parametrize(
"kwargs,expected_table",
[
(dict(extracts={"species_id": "Species"}), "Species"),
(dict(extracts=["species_id"]), "species_id"),
(dict(extracts=("species_id",)), "species_id"),
],
)
@pytest.mark.parametrize("use_table_factory", [True, False])
def test_extracts(fresh_db, kwargs, expected_table, use_table_factory):
table_kwargs = {}
insert_kwargs = {}
if use_table_factory:
table_kwargs = kwargs
else:
insert_kwargs = kwargs
trees = fresh_db.table("Trees", **table_kwargs)
trees.insert_all(
[
{"id": 1, "species_id": "Oak"},
{"id": 2, "species_id": "Oak"},
{"id": 3, "species_id": "Palm"},
],
**insert_kwargs
)
# Should now have two tables: Trees and Species
assert {expected_table, "Trees"} == set(fresh_db.table_names())
assert (
"CREATE TABLE [{}] (\n [id] INTEGER PRIMARY KEY,\n [value] TEXT\n)".format(
expected_table
)
== fresh_db[expected_table].schema
)
assert (
"CREATE TABLE [Trees] (\n [id] INTEGER,\n [species_id] INTEGER REFERENCES [{}]([id])\n)".format(
expected_table
)
== fresh_db["Trees"].schema
)
# Should have a foreign key reference
assert len(fresh_db["Trees"].foreign_keys) == 1
fk = fresh_db["Trees"].foreign_keys[0]
assert fk.table == "Trees"
assert fk.column == "species_id"
# Should have unique index on Species
> assert [
Index(
seq=0,
name="idx_{}_value".format(expected_table),
unique=1,
origin="c",
partial=0,
columns=["value"],
)
] == fresh_db[expected_table].indexes
E AssertionError: assert [Index(seq=0,...ns=['value'])] == []
E
E Left contains one more item: Index(seq=0, name='idx_Species_value', unique=1, origin='c', partial=0, columns=['value'])
E Use -v to get more diff
tests/test_extracts.py:51: AssertionError
___________________ test_extracts[False-kwargs1-species_id] ____________________
fresh_db = <Database <sqlite3.Connection object at 0x1051de110>>
kwargs = {'extracts': ['species_id']}, expected_table = 'species_id'
use_table_factory = False
@pytest.mark.parametrize(
"kwargs,expected_table",
[
(dict(extracts={"species_id": "Species"}), "Species"),
(dict(extracts=["species_id"]), "species_id"),
(dict(extracts=("species_id",)), "species_id"),
],
)
@pytest.mark.parametrize("use_table_factory", [True, False])
def test_extracts(fresh_db, kwargs, expected_table, use_table_factory):
table_kwargs = {}
insert_kwargs = {}
if use_table_factory:
table_kwargs = kwargs
else:
insert_kwargs = kwargs
trees = fresh_db.table("Trees", **table_kwargs)
trees.insert_all(
[
{"id": 1, "species_id": "Oak"},
{"id": 2, "species_id": "Oak"},
{"id": 3, "species_id": "Palm"},
],
**insert_kwargs
)
# Should now have two tables: Trees and Species
assert {expected_table, "Trees"} == set(fresh_db.table_names())
assert (
"CREATE TABLE [{}] (\n [id] INTEGER PRIMARY KEY,\n [value] TEXT\n)".format(
expected_table
)
== fresh_db[expected_table].schema
)
assert (
"CREATE TABLE [Trees] (\n [id] INTEGER,\n [species_id] INTEGER REFERENCES [{}]([id])\n)".format(
expected_table
)
== fresh_db["Trees"].schema
)
# Should have a foreign key reference
assert len(fresh_db["Trees"].foreign_keys) == 1
fk = fresh_db["Trees"].foreign_keys[0]
assert fk.table == "Trees"
assert fk.column == "species_id"
# Should have unique index on Species
> assert [
Index(
seq=0,
name="idx_{}_value".format(expected_table),
unique=1,
origin="c",
partial=0,
columns=["value"],
)
] == fresh_db[expected_table].indexes
E AssertionError: assert [Index(seq=0,...ns=['value'])] == []
E
E Left contains one more item: Index(seq=0, name='idx_species_id_value', unique=1, origin='c', partial=0, columns=['value'])
E Use -v to get more diff
tests/test_extracts.py:51: AssertionError
___________________ test_extracts[False-kwargs2-species_id] ____________________
fresh_db = <Database <sqlite3.Connection object at 0x1051df6a0>>
kwargs = {'extracts': ('species_id',)}, expected_table = 'species_id'
use_table_factory = False
@pytest.mark.parametrize(
"kwargs,expected_table",
[
(dict(extracts={"species_id": "Species"}), "Species"),
(dict(extracts=["species_id"]), "species_id"),
(dict(extracts=("species_id",)), "species_id"),
],
)
@pytest.mark.parametrize("use_table_factory", [True, False])
def test_extracts(fresh_db, kwargs, expected_table, use_table_factory):
table_kwargs = {}
insert_kwargs = {}
if use_table_factory:
table_kwargs = kwargs
else:
insert_kwargs = kwargs
trees = fresh_db.table("Trees", **table_kwargs)
trees.insert_all(
[
{"id": 1, "species_id": "Oak"},
{"id": 2, "species_id": "Oak"},
{"id": 3, "species_id": "Palm"},
],
**insert_kwargs
)
# Should now have two tables: Trees and Species
assert {expected_table, "Trees"} == set(fresh_db.table_names())
assert (
"CREATE TABLE [{}] (\n [id] INTEGER PRIMARY KEY,\n [value] TEXT\n)".format(
expected_table
)
== fresh_db[expected_table].schema
)
assert (
"CREATE TABLE [Trees] (\n [id] INTEGER,\n [species_id] INTEGER REFERENCES [{}]([id])\n)".format(
expected_table
)
== fresh_db["Trees"].schema
)
# Should have a foreign key reference
assert len(fresh_db["Trees"].foreign_keys) == 1
fk = fresh_db["Trees"].foreign_keys[0]
assert fk.table == "Trees"
assert fk.column == "species_id"
# Should have unique index on Species
> assert [
Index(
seq=0,
name="idx_{}_value".format(expected_table),
unique=1,
origin="c",
partial=0,
columns=["value"],
)
] == fresh_db[expected_table].indexes
E AssertionError: assert [Index(seq=0,...ns=['value'])] == []
E
E Left contains one more item: Index(seq=0, name='idx_species_id_value', unique=1, origin='c', partial=0, columns=['value'])
E Use -v to get more diff
tests/test_extracts.py:51: AssertionError
___________________ test_lookup_with_extra_insert_parameters ___________________
fresh_db = <Database <sqlite3.Connection object at 0x105162980>>
def test_lookup_with_extra_insert_parameters(fresh_db):
other_table = fresh_db["other_table"]
other_table.insert({"id": 1, "name": "Name"}, pk="id")
species = fresh_db["species"]
id = species.lookup(
{"name": "Palm", "type": "Tree"},
{
"first_seen": "2020-01-01",
"make_not_null": 1,
"fk_to_other": 1,
"default_is_dog": "cat",
"extract_this": "This is extracted",
"convert_to_upper": "upper",
"make_this_integer": "2",
"this_at_front": 1,
},
pk="renamed_id",
foreign_keys=(("fk_to_other", "other_table", "id"),),
column_order=("this_at_front",),
not_null={"make_not_null"},
defaults={"default_is_dog": "dog"},
extracts=["extract_this"],
conversions={"convert_to_upper": "upper(?)"},
columns={"make_this_integer": int},
)
assert species.schema == (
"CREATE TABLE [species] (\n"
" [renamed_id] INTEGER PRIMARY KEY,\n"
" [this_at_front] INTEGER,\n"
" [name] TEXT,\n"
" [type] TEXT,\n"
" [first_seen] TEXT,\n"
" [make_not_null] INTEGER NOT NULL,\n"
" [fk_to_other] INTEGER REFERENCES [other_table]([id]),\n"
" [default_is_dog] TEXT DEFAULT 'dog',\n"
" [extract_this] INTEGER REFERENCES [extract_this]([id]),\n"
" [convert_to_upper] TEXT,\n"
" [make_this_integer] INTEGER\n"
")"
)
> assert species.get(id) == {
"renamed_id": id,
"this_at_front": 1,
"name": "Palm",
"type": "Tree",
"first_seen": "2020-01-01",
"make_not_null": 1,
"fk_to_other": 1,
"default_is_dog": "cat",
"extract_this": 1,
"convert_to_upper": "UPPER",
"make_this_integer": 2,
}
E AssertionError: assert {'convert_to_...0-01-01', ...} == {'convert_to_...0-01-01', ...}
E
E Omitting 10 identical items, use -vv to show
E Differing items:
E {'extract_this': 'This is extracted'} != {'extract_this': 1}
E Use -v to get more diff
tests/test_lookup.py:131: AssertionError
=========================== short test summary info ============================
FAILED tests/test_extracts.py::test_extracts[True-kwargs0-Species] - Assertio...
FAILED tests/test_extracts.py::test_extracts[True-kwargs1-species_id] - Asser...
FAILED tests/test_extracts.py::test_extracts[True-kwargs2-species_id] - Asser...
FAILED tests/test_extracts.py::test_extracts[False-kwargs0-Species] - Asserti...
FAILED tests/test_extracts.py::test_extracts[False-kwargs1-species_id] - Asse...
FAILED tests/test_extracts.py::test_extracts[False-kwargs2-species_id] - Asse...
FAILED tests/test_lookup.py::test_lookup_with_extra_insert_parameters - Asser...
================== 7 failed, 1000 passed, 4 skipped in 5.86s ===================
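All seven failures involve the extracts= option. The six test_extracts parametrizations expect the extracted lookup table to carry a unique index on its value column, and test_lookup_with_extra_insert_parameters expects lookup() to replace the extracted value with its integer foreign key; in this run the index is missing and the raw string is stored instead. A minimal sketch of the behaviour the failing tests expect, using only the calls exercised by tests/test_extracts.py above, run against an in-memory database:

from sqlite_utils import Database

db = Database(memory=True)
db["Trees"].insert_all(
    [
        {"id": 1, "species_id": "Oak"},
        {"id": 2, "species_id": "Oak"},
        {"id": 3, "species_id": "Palm"},
    ],
    extracts={"species_id": "Species"},
)
print(db.table_names())      # tests expect both "Trees" and "Species"
print(db["Species"].schema)  # lookup table: [id] INTEGER PRIMARY KEY, [value] TEXT
print(db["Species"].indexes) # tests expect a unique index on [value]; this run returned []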
b3dbdee1652345a085b27fcefb693ff342f808427484a9218b16798f404a05c0
<documents>
<document index="1">
<source>./churn_em_out.py</source>
<document_content>
i = 0
while True:
i += 1
print(
'{"id": I, "another": "row", "number": J}'.replace("I", str(i)).replace(
"J", str(i + 1)
)
)
</document_content>
</document>
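churn_em_out.py emits an endless stream of newline-delimited JSON objects with incrementing id values, which suggests it exists to feed a streaming import. A hedged usage sketch (the database name churn.db, the table name rows, and the head limit are invented for illustration; only the sqlite-utils insert options shown later in cli.py are relied on):

python churn_em_out.py | head -n 100000 | \
  sqlite-utils insert churn.db rows - --nl --batch-size 1000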
<document index="2">
<source>./setup.py</source>
<document_content>
from setuptools import setup, find_packages
import io
import os
VERSION = "3.38"
def get_long_description():
with io.open(
os.path.join(os.path.dirname(os.path.abspath(__file__)), "README.md"),
encoding="utf8",
) as fp:
return fp.read()
setup(
name="sqlite-utils",
description="CLI tool and Python library for manipulating SQLite databases",
long_description=get_long_description(),
long_description_content_type="text/markdown",
author="Simon Willison",
version=VERSION,
license="Apache License, Version 2.0",
packages=find_packages(exclude=["tests", "tests.*"]),
package_data={"sqlite_utils": ["py.typed"]},
install_requires=[
"sqlite-fts4",
"click",
"click-default-group>=1.2.3",
"tabulate",
"python-dateutil",
"pluggy",
],
extras_require={
"test": ["pytest", "black>=24.1.1", "hypothesis", "cogapp"],
"docs": [
"furo",
"sphinx-autobuild",
"codespell",
"sphinx-copybutton",
"beanbag-docutils>=2.0",
"pygments-csv-lexer",
],
"mypy": [
"mypy",
"types-click",
"types-tabulate",
"types-python-dateutil",
"types-pluggy",
"data-science-types",
],
"flake8": ["flake8"],
},
entry_points="""
[console_scripts]
sqlite-utils=sqlite_utils.cli:cli
""",
url="https://github.com/simonw/sqlite-utils",
project_urls={
"Documentation": "https://sqlite-utils.datasette.io/en/stable/",
"Changelog": "https://sqlite-utils.datasette.io/en/stable/changelog.html",
"Source code": "https://github.com/simonw/sqlite-utils",
"Issues": "https://github.com/simonw/sqlite-utils/issues",
"CI": "https://github.com/simonw/sqlite-utils/actions",
},
python_requires=">=3.9",
classifiers=[
"Development Status :: 5 - Production/Stable",
"Intended Audience :: Developers",
"Intended Audience :: Science/Research",
"Intended Audience :: End Users/Desktop",
"Topic :: Database",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
],
# Needed to bundle py.typed so mypy can see it:
zip_safe=False,
)
</document_content>
</document>
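setup.py declares a "test" extra (pytest, black, hypothesis, cogapp), so one plausible way to reproduce the run at the top of this log is an editable install of that extra followed by pytest; the exact commands below are an assumption about the workflow, not something stated in the repository:

pip install -e '.[test]'
python -m pytest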
<document index="3">
<source>./sqlite_utils/__init__.py</source>
<document_content>
from .utils import suggest_column_types
from .hookspecs import hookimpl
from .hookspecs import hookspec
from .db import Database
__all__ = ["Database", "suggest_column_types", "hookimpl", "hookspec"]
</document_content>
</document>
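sqlite_utils/__init__.py re-exports Database alongside suggest_column_types and the plugin hook decorators, so the library can be used directly without the CLI. A minimal sketch (the table and column names are invented for illustration):

from sqlite_utils import Database

db = Database(memory=True)
db["chickens"].insert({"id": 1, "name": "Lila"}, pk="id")
print(db["chickens"].schema)  # prints the generated CREATE TABLE statement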
<document index="4">
<source>./sqlite_utils/__main__.py</source>
<document_content>
from .cli import cli
if __name__ == "__main__":
cli()
</document_content>
</document>
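__main__.py makes the package runnable as a module, equivalent to the sqlite-utils console script declared in setup.py. For example (the database name is assumed):

python -m sqlite_utils tables data.db --counts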
<document index="5">
<source>./sqlite_utils/cli.py</source>
<document_content>
import base64
import click
from click_default_group import DefaultGroup # type: ignore
from datetime import datetime
import hashlib
import pathlib
from runpy import run_module
import sqlite_utils
from sqlite_utils.db import AlterError, BadMultiValues, DescIndex, NoTable
from sqlite_utils.plugins import pm, get_plugins
from sqlite_utils.utils import maximize_csv_field_size_limit
from sqlite_utils import recipes
import textwrap
import inspect
import io
import itertools
import json
import os
import pdb
import sys
import csv as csv_std
import tabulate
from .utils import (
OperationalError,
_compile_code,
chunks,
file_progress,
find_spatialite,
flatten as _flatten,
sqlite3,
decode_base64_values,
progressbar,
rows_from_file,
Format,
TypeTracker,
)
CONTEXT_SETTINGS = dict(help_option_names=["-h", "--help"])
VALID_COLUMN_TYPES = ("INTEGER", "TEXT", "FLOAT", "BLOB")
UNICODE_ERROR = """
{}
The input you provided uses a character encoding other than utf-8.
You can fix this by passing the --encoding= option with the encoding of the file.
If you do not know the encoding, running 'file filename.csv' may tell you.
It's often worth trying: --encoding=latin-1
""".strip()
maximize_csv_field_size_limit()
class CaseInsensitiveChoice(click.Choice):
def __init__(self, choices):
super().__init__([choice.lower() for choice in choices])
def convert(self, value, param, ctx):
return super().convert(value.lower(), param, ctx)
def output_options(fn):
for decorator in reversed(
(
click.option(
"--nl",
help="Output newline-delimited JSON",
is_flag=True,
default=False,
),
click.option(
"--arrays",
help="Output rows as arrays instead of objects",
is_flag=True,
default=False,
),
click.option("--csv", is_flag=True, help="Output CSV"),
click.option("--tsv", is_flag=True, help="Output TSV"),
click.option("--no-headers", is_flag=True, help="Omit CSV headers"),
click.option(
"-t", "--table", is_flag=True, help="Output as a formatted table"
),
click.option(
"--fmt",
help="Table format - one of {}".format(
", ".join(tabulate.tabulate_formats)
),
),
click.option(
"--json-cols",
help="Detect JSON cols and output them as JSON, not escaped strings",
is_flag=True,
default=False,
),
)
):
fn = decorator(fn)
return fn
def load_extension_option(fn):
return click.option(
"--load-extension",
multiple=True,
help="Path to SQLite extension, with optional :entrypoint",
)(fn)
@click.group(
cls=DefaultGroup,
default="query",
default_if_no_args=True,
context_settings=CONTEXT_SETTINGS,
)
@click.version_option()
def cli():
"Commands for interacting with a SQLite database"
pass
@cli.command()
@click.argument(
"path",
type=click.Path(exists=True, file_okay=True, dir_okay=False, allow_dash=False),
required=True,
)
@click.option(
"--fts4", help="Just show FTS4 enabled tables", default=False, is_flag=True
)
@click.option(
"--fts5", help="Just show FTS5 enabled tables", default=False, is_flag=True
)
@click.option(
"--counts", help="Include row counts per table", default=False, is_flag=True
)
@output_options
@click.option(
"--columns",
help="Include list of columns for each table",
is_flag=True,
default=False,
)
@click.option(
"--schema",
help="Include schema for each table",
is_flag=True,
default=False,
)
@load_extension_option
def tables(
path,
fts4,
fts5,
counts,
nl,
arrays,
csv,
tsv,
no_headers,
table,
fmt,
json_cols,
columns,
schema,
load_extension,
views=False,
):
"""List the tables in the database
Example:
\b
sqlite-utils tables trees.db
"""
db = sqlite_utils.Database(path)
_load_extensions(db, load_extension)
headers = ["view" if views else "table"]
if counts:
headers.append("count")
if columns:
headers.append("columns")
if schema:
headers.append("schema")
def _iter():
if views:
items = db.view_names()
else:
items = db.table_names(fts4=fts4, fts5=fts5)
for name in items:
row = [name]
if counts:
row.append(db[name].count)
if columns:
cols = [c.name for c in db[name].columns]
if csv:
row.append("\n".join(cols))
else:
row.append(cols)
if schema:
row.append(db[name].schema)
yield row
if table or fmt:
print(tabulate.tabulate(_iter(), headers=headers, tablefmt=fmt or "simple"))
elif csv or tsv:
writer = csv_std.writer(sys.stdout, dialect="excel-tab" if tsv else "excel")
if not no_headers:
writer.writerow(headers)
for row in _iter():
writer.writerow(row)
else:
for line in output_rows(_iter(), headers, nl, arrays, json_cols):
click.echo(line)
@cli.command()
@click.argument(
"path",
type=click.Path(exists=True, file_okay=True, dir_okay=False, allow_dash=False),
required=True,
)
@click.option(
"--counts", help="Include row counts per view", default=False, is_flag=True
)
@output_options
@click.option(
"--columns",
help="Include list of columns for each view",
is_flag=True,
default=False,
)
@click.option(
"--schema",
help="Include schema for each view",
is_flag=True,
default=False,
)
@load_extension_option
def views(
path,
counts,
nl,
arrays,
csv,
tsv,
no_headers,
table,
fmt,
json_cols,
columns,
schema,
load_extension,
):
"""List the views in the database
Example:
\b
sqlite-utils views trees.db
"""
tables.callback(
path=path,
fts4=False,
fts5=False,
counts=counts,
nl=nl,
arrays=arrays,
csv=csv,
tsv=tsv,
no_headers=no_headers,
table=table,
fmt=fmt,
json_cols=json_cols,
columns=columns,
schema=schema,
load_extension=load_extension,
views=True,
)
@cli.command()
@click.argument(
"path",
type=click.Path(exists=True, file_okay=True, dir_okay=False, allow_dash=False),
required=True,
)
@click.argument("tables", nargs=-1)
@click.option("--no-vacuum", help="Don't run VACUUM", default=False, is_flag=True)
@load_extension_option
def optimize(path, tables, no_vacuum, load_extension):
"""Optimize all full-text search tables and then run VACUUM - should shrink the database file
Example:
\b
sqlite-utils optimize chickens.db
"""
db = sqlite_utils.Database(path)
_load_extensions(db, load_extension)
if not tables:
tables = db.table_names(fts4=True) + db.table_names(fts5=True)
with db.conn:
for table in tables:
db[table].optimize()
if not no_vacuum:
db.vacuum()
@cli.command(name="rebuild-fts")
@click.argument(
"path",
type=click.Path(exists=True, file_okay=True, dir_okay=False, allow_dash=False),
required=True,
)
@click.argument("tables", nargs=-1)
@load_extension_option
def rebuild_fts(path, tables, load_extension):
"""Rebuild all or specific full-text search tables
Example:
\b
sqlite-utils rebuild-fts chickens.db chickens
"""
db = sqlite_utils.Database(path)
_load_extensions(db, load_extension)
if not tables:
tables = db.table_names(fts4=True) + db.table_names(fts5=True)
with db.conn:
for table in tables:
db[table].rebuild_fts()
@cli.command()
@click.argument(
"path",
type=click.Path(exists=True, file_okay=True, dir_okay=False, allow_dash=False),
required=True,
)
@click.argument("names", nargs=-1)
def analyze(path, names):
"""Run ANALYZE against the whole database, or against specific named indexes and tables
Example:
\b
sqlite-utils analyze chickens.db
"""
db = sqlite_utils.Database(path)
try:
if names:
for name in names:
db.analyze(name)
else:
db.analyze()
except OperationalError as e:
raise click.ClickException(e)
@cli.command()
@click.argument(
"path",
type=click.Path(exists=True, file_okay=True, dir_okay=False, allow_dash=False),
required=True,
)
def vacuum(path):
"""Run VACUUM against the database
Example:
\b
sqlite-utils vacuum chickens.db
"""
sqlite_utils.Database(path).vacuum()
@cli.command()
@click.argument(
"path",
type=click.Path(exists=True, file_okay=True, dir_okay=False, allow_dash=False),
required=True,
)
@load_extension_option
def dump(path, load_extension):
"""Output a SQL dump of the schema and full contents of the database
Example:
\b
sqlite-utils dump chickens.db
"""
db = sqlite_utils.Database(path)
_load_extensions(db, load_extension)
for line in db.iterdump():
click.echo(line)
@cli.command(name="add-column")
@click.argument(
"path",
type=click.Path(exists=True, file_okay=True, dir_okay=False, allow_dash=False),
required=True,
)
@click.argument("table")
@click.argument("col_name")
@click.argument(
"col_type",
type=click.Choice(
["integer", "int", "float", "text", "str", "blob", "bytes"],
case_sensitive=False,
),
required=False,
)
@click.option(
"--fk", type=str, required=False, help="Table to reference as a foreign key"
)
@click.option(
"--fk-col",
type=str,
required=False,
help="Referenced column on that foreign key table - if omitted will automatically use the primary key",
)
@click.option(
"--not-null-default",
type=str,
required=False,
help="Add NOT NULL DEFAULT 'TEXT' constraint",
)
@click.option(
"--ignore",
is_flag=True,
help="If column already exists, do nothing",
)
@load_extension_option
def add_column(
path,
table,
col_name,
col_type,
fk,
fk_col,
not_null_default,
ignore,
load_extension,
):
"""Add a column to the specified table
Example:
\b
sqlite-utils add-column chickens.db chickens weight float
"""
db = sqlite_utils.Database(path)
_load_extensions(db, load_extension)
try:
db[table].add_column(
col_name, col_type, fk=fk, fk_col=fk_col, not_null_default=not_null_default
)
except OperationalError as ex:
if not ignore:
raise click.ClickException(str(ex))
@cli.command(name="add-foreign-key")
@click.argument(
"path",
type=click.Path(exists=True, file_okay=True, dir_okay=False, allow_dash=False),
required=True,
)
@click.argument("table")
@click.argument("column")
@click.argument("other_table", required=False)
@click.argument("other_column", required=False)
@click.option(
"--ignore",
is_flag=True,
help="If foreign key already exists, do nothing",
)
@load_extension_option
def add_foreign_key(
path, table, column, other_table, other_column, ignore, load_extension
):
"""
Add a new foreign key constraint to an existing table
Example:
sqlite-utils add-foreign-key my.db books author_id authors id
"""
db = sqlite_utils.Database(path)
_load_extensions(db, load_extension)
try:
db[table].add_foreign_key(column, other_table, other_column, ignore=ignore)
except AlterError as e:
raise click.ClickException(e)
@cli.command(name="add-foreign-keys")
@click.argument(
"path",
type=click.Path(exists=True, file_okay=True, dir_okay=False, allow_dash=False),
required=True,
)
@click.argument("foreign_key", nargs=-1)
@load_extension_option
def add_foreign_keys(path, foreign_key, load_extension):
"""
Add multiple new foreign key constraints to a database
Example:
\b
sqlite-utils add-foreign-keys my.db \\
books author_id authors id \\
authors country_id countries id
"""
db = sqlite_utils.Database(path)
_load_extensions(db, load_extension)
if len(foreign_key) % 4 != 0:
raise click.ClickException(
"Each foreign key requires four values: table, column, other_table, other_column"
)
tuples = []
for i in range(len(foreign_key) // 4):
tuples.append(tuple(foreign_key[i * 4 : (i * 4) + 4]))
try:
db.add_foreign_keys(tuples)
except AlterError as e:
raise click.ClickException(e)
@cli.command(name="index-foreign-keys")
@click.argument(
"path",
type=click.Path(exists=True, file_okay=True, dir_okay=False, allow_dash=False),
required=True,
)
@load_extension_option
def index_foreign_keys(path, load_extension):
"""
Ensure every foreign key column has an index on it
Example:
\b
sqlite-utils index-foreign-keys chickens.db
"""
db = sqlite_utils.Database(path)
_load_extensions(db, load_extension)
db.index_foreign_keys()
@cli.command(name="create-index")
@click.argument(
"path",
type=click.Path(exists=True, file_okay=True, dir_okay=False, allow_dash=False),
required=True,
)
@click.argument("table")
@click.argument("column", nargs=-1, required=True)
@click.option("--name", help="Explicit name for the new index")
@click.option("--unique", help="Make this a unique index", default=False, is_flag=True)
@click.option(
"--if-not-exists",
"--ignore",
help="Ignore if index already exists",
default=False,
is_flag=True,
)
@click.option(
"--analyze",
help="Run ANALYZE after creating the index",
is_flag=True,
)
@load_extension_option
def create_index(
path, table, column, name, unique, if_not_exists, analyze, load_extension
):
"""
Add an index to the specified table for the specified columns
Example:
\b
sqlite-utils create-index chickens.db chickens name
To create an index in descending order:
\b
sqlite-utils create-index chickens.db chickens -- -name
"""
db = sqlite_utils.Database(path)
_load_extensions(db, load_extension)
# Treat -prefix as descending for columns
columns = []
for col in column:
if col.startswith("-"):
col = DescIndex(col[1:])
columns.append(col)
db[table].create_index(
columns,
index_name=name,
unique=unique,
if_not_exists=if_not_exists,
analyze=analyze,
)
@cli.command(name="enable-fts")
@click.argument(
"path",
type=click.Path(exists=True, file_okay=True, dir_okay=False, allow_dash=False),
required=True,
)
@click.argument("table")
@click.argument("column", nargs=-1, required=True)
@click.option("--fts4", help="Use FTS4", default=False, is_flag=True)
@click.option("--fts5", help="Use FTS5", default=False, is_flag=True)
@click.option("--tokenize", help="Tokenizer to use, e.g. porter")
@click.option(
"--create-triggers",
help="Create triggers to update the FTS tables when the parent table changes.",
default=False,
is_flag=True,
)
@click.option(
"--replace",
is_flag=True,
help="Replace existing FTS configuration if it exists",
)
@load_extension_option
def enable_fts(
path, table, column, fts4, fts5, tokenize, create_triggers, replace, load_extension
):
"""Enable full-text search for specific table and columns"
Example:
\b
sqlite-utils enable-fts chickens.db chickens name
"""
fts_version = "FTS5"
if fts4 and fts5:
click.echo("Can only use one of --fts4 or --fts5", err=True)
return
elif fts4:
fts_version = "FTS4"
db = sqlite_utils.Database(path)
_load_extensions(db, load_extension)
try:
db[table].enable_fts(
column,
fts_version=fts_version,
tokenize=tokenize,
create_triggers=create_triggers,
replace=replace,
)
except OperationalError as ex:
raise click.ClickException(ex)
@cli.command(name="populate-fts")
@click.argument(
"path",
type=click.Path(exists=True, file_okay=True, dir_okay=False, allow_dash=False),
required=True,
)
@click.argument("table")
@click.argument("column", nargs=-1, required=True)
@load_extension_option
def populate_fts(path, table, column, load_extension):
"""Re-populate full-text search for specific table and columns
Example:
\b
sqlite-utils populate-fts chickens.db chickens name
"""
db = sqlite_utils.Database(path)
_load_extensions(db, load_extension)
db[table].populate_fts(column)
@cli.command(name="disable-fts")
@click.argument(
"path",
type=click.Path(exists=True, file_okay=True, dir_okay=False, allow_dash=False),
required=True,
)
@click.argument("table")
@load_extension_option
def disable_fts(path, table, load_extension):
"""Disable full-text search for specific table
Example:
\b
sqlite-utils disable-fts chickens.db chickens
"""
db = sqlite_utils.Database(path)
_load_extensions(db, load_extension)
db[table].disable_fts()
@cli.command(name="enable-wal")
@click.argument(
"path",
nargs=-1,
type=click.Path(exists=True, file_okay=True, dir_okay=False, allow_dash=False),
required=True,
)
@load_extension_option
def enable_wal(path, load_extension):
"""Enable WAL for database files
Example:
\b
sqlite-utils enable-wal chickens.db
"""
for path_ in path:
db = sqlite_utils.Database(path_)
_load_extensions(db, load_extension)
db.enable_wal()
@cli.command(name="disable-wal")
@click.argument(
"path",
nargs=-1,
type=click.Path(exists=True, file_okay=True, dir_okay=False, allow_dash=False),
required=True,
)
@load_extension_option
def disable_wal(path, load_extension):
"""Disable WAL for database files
Example:
\b
sqlite-utils disable-wal chickens.db
"""
for path_ in path:
db = sqlite_utils.Database(path_)
_load_extensions(db, load_extension)
db.disable_wal()
@cli.command(name="enable-counts")
@click.argument(
"path",
type=click.Path(exists=True, file_okay=True, dir_okay=False, allow_dash=False),
required=True,
)
@click.argument("tables", nargs=-1)
@load_extension_option
def enable_counts(path, tables, load_extension):
"""Configure triggers to update a _counts table with row counts
Example:
\b
sqlite-utils enable-counts chickens.db
"""
db = sqlite_utils.Database(path)
_load_extensions(db, load_extension)
if not tables:
db.enable_counts()
else:
# Check all tables exist
bad_tables = [table for table in tables if not db[table].exists()]
if bad_tables:
raise click.ClickException("Invalid tables: {}".format(bad_tables))
for table in tables:
db[table].enable_counts()
@cli.command(name="reset-counts")
@click.argument(
"path",
type=click.Path(exists=True, file_okay=True, dir_okay=False, allow_dash=False),
required=True,
)
@load_extension_option
def reset_counts(path, load_extension):
"""Reset calculated counts in the _counts table
Example:
\b
sqlite-utils reset-counts chickens.db
"""
db = sqlite_utils.Database(path)
_load_extensions(db, load_extension)
db.reset_counts()
_import_options = (
click.option(
"--flatten",
is_flag=True,
help='Flatten nested JSON objects, so {"a": {"b": 1}} becomes {"a_b": 1}',
),
click.option("--nl", is_flag=True, help="Expect newline-delimited JSON"),
click.option("-c", "--csv", is_flag=True, help="Expect CSV input"),
click.option("--tsv", is_flag=True, help="Expect TSV input"),
click.option("--empty-null", is_flag=True, help="Treat empty strings as NULL"),
click.option(
"--lines",
is_flag=True,
help="Treat each line as a single value called 'line'",
),
click.option(
"--text",
is_flag=True,
help="Treat input as a single value called 'text'",
),
click.option("--convert", help="Python code to convert each item"),
click.option(
"--import",
"imports",
type=str,
multiple=True,
help="Python modules to import",
),
click.option("--delimiter", help="Delimiter to use for CSV files"),
click.option("--quotechar", help="Quote character to use for CSV/TSV"),
click.option("--sniff", is_flag=True, help="Detect delimiter and quote character"),
click.option("--no-headers", is_flag=True, help="CSV file has no header row"),
click.option(
"--encoding",
help="Character encoding for input, defaults to utf-8",
),
)
def import_options(fn):
for decorator in reversed(_import_options):
fn = decorator(fn)
return fn
def insert_upsert_options(*, require_pk=False):
def inner(fn):
for decorator in reversed(
(
click.argument(
"path",
type=click.Path(file_okay=True, dir_okay=False, allow_dash=False),
required=True,
),
click.argument("table"),
click.argument("file", type=click.File("rb"), required=True),
click.option(
"--pk",
help="Columns to use as the primary key, e.g. id",
multiple=True,
required=require_pk,
),
)
+ _import_options
+ (
click.option(
"--batch-size", type=int, default=100, help="Commit every X records"
),
click.option("--stop-after", type=int, help="Stop after X records"),
click.option(
"--alter",
is_flag=True,
help="Alter existing table to add any missing columns",
),
click.option(
"--not-null",
multiple=True,
help="Columns that should be created as NOT NULL",
),
click.option(
"--default",
multiple=True,
type=(str, str),
help="Default value that should be set for a column",
),
click.option(
"-d",
"--detect-types",
is_flag=True,
envvar="SQLITE_UTILS_DETECT_TYPES",
help="Detect types for columns in CSV/TSV data",
),
click.option(
"--analyze",
is_flag=True,
help="Run ANALYZE at the end of this operation",
),
load_extension_option,
click.option("--silent", is_flag=True, help="Do not show progress bar"),
click.option(
"--strict",
is_flag=True,
default=False,
help="Apply STRICT mode to created table",
),
)
):
fn = decorator(fn)
return fn
return inner
def insert_upsert_implementation(
path,
table,
file,
pk,
flatten,
nl,
csv,
tsv,
empty_null,
lines,
text,
convert,
imports,
delimiter,
quotechar,
sniff,
no_headers,
encoding,
batch_size,
stop_after,
alter,
upsert,
ignore=False,
replace=False,
truncate=False,
not_null=None,
default=None,
detect_types=None,
analyze=False,
load_extension=None,
silent=False,
bulk_sql=None,
functions=None,
strict=False,
):
db = sqlite_utils.Database(path)
_load_extensions(db, load_extension)
if functions:
_register_functions(db, functions)
if (delimiter or quotechar or sniff or no_headers) and not tsv:
csv = True
if (nl + csv + tsv) >= 2:
raise click.ClickException("Use just one of --nl, --csv or --tsv")
if (csv or tsv) and flatten:
raise click.ClickException("--flatten cannot be used with --csv or --tsv")
if empty_null and not (csv or tsv):
raise click.ClickException("--empty-null can only be used with --csv or --tsv")
if encoding and not (csv or tsv):
raise click.ClickException("--encoding must be used with --csv or --tsv")
if pk and len(pk) == 1:
pk = pk[0]
encoding = encoding or "utf-8-sig"
# The --sniff option needs us to buffer the file to peek ahead
sniff_buffer = None
decoded_buffer = None
if sniff:
sniff_buffer = io.BufferedReader(file, buffer_size=4096)
decoded_buffer = io.TextIOWrapper(sniff_buffer, encoding=encoding)
else:
decoded_buffer = io.TextIOWrapper(file, encoding=encoding)
tracker = None
with file_progress(decoded_buffer, silent=silent) as decoded:
if csv or tsv:
if sniff:
# Read first 2048 bytes and use that to detect
first_bytes = sniff_buffer.peek(2048)
dialect = csv_std.Sniffer().sniff(
first_bytes.decode(encoding, "ignore")
)
else:
dialect = "excel-tab" if tsv else "excel"
csv_reader_args = {"dialect": dialect}
if delimiter:
csv_reader_args["delimiter"] = delimiter
if quotechar:
csv_reader_args["quotechar"] = quotechar
reader = csv_std.reader(decoded, **csv_reader_args)
first_row = next(reader)
if no_headers:
headers = ["untitled_{}".format(i + 1) for i in range(len(first_row))]
reader = itertools.chain([first_row], reader)
else:
headers = first_row
if empty_null:
docs = (
dict(zip(headers, [None if cell == "" else cell for cell in row]))
for row in reader
)
else:
docs = (dict(zip(headers, row)) for row in reader)
if detect_types:
tracker = TypeTracker()
docs = tracker.wrap(docs)
elif lines:
docs = ({"line": line.strip()} for line in decoded)
elif text:
docs = ({"text": decoded.read()},)
else:
try:
if nl:
docs = (json.loads(line) for line in decoded if line.strip())
else:
docs = json.load(decoded)
if isinstance(docs, dict):
docs = [docs]
except json.decoder.JSONDecodeError as ex:
raise click.ClickException(
"Invalid JSON - use --csv for CSV or --tsv for TSV files\n\nJSON error: {}".format(
ex
)
)
if flatten:
docs = (_flatten(doc) for doc in docs)
if stop_after:
docs = itertools.islice(docs, stop_after)
if convert:
variable = "row"
if lines:
variable = "line"
elif text:
variable = "text"
fn = _compile_code(convert, imports, variable=variable)
if lines:
docs = (fn(doc["line"]) for doc in docs)
elif text:
# Special case: this is allowed to be an iterable
text_value = list(docs)[0]["text"]
fn_return = fn(text_value)
if isinstance(fn_return, dict):
docs = [fn_return]
else:
try:
docs = iter(fn_return)
except TypeError:
raise click.ClickException(
"--convert must return dict or iterator"
)
else:
docs = (fn(doc) or doc for doc in docs)
extra_kwargs = {
"ignore": ignore,
"replace": replace,
"truncate": truncate,
"analyze": analyze,
"strict": strict,
}
if not_null:
extra_kwargs["not_null"] = set(not_null)
if default:
extra_kwargs["defaults"] = dict(default)
if upsert:
extra_kwargs["upsert"] = upsert
# docs should all be dictionaries
docs = (verify_is_dict(doc) for doc in docs)
# Apply {"$base64": true, ...} decoding, if needed
docs = (decode_base64_values(doc) for doc in docs)
# For bulk_sql= we use cursor.executemany() instead
if bulk_sql:
if batch_size:
doc_chunks = chunks(docs, batch_size)
else:
doc_chunks = [docs]
for doc_chunk in doc_chunks:
with db.conn:
db.conn.cursor().executemany(bulk_sql, doc_chunk)
return
try:
db[table].insert_all(
docs, pk=pk, batch_size=batch_size, alter=alter, **extra_kwargs
)
except Exception as e:
if (
isinstance(e, OperationalError)
and e.args
and "has no column named" in e.args[0]
):
raise click.ClickException(
"{}\n\nTry using --alter to add additional columns".format(
e.args[0]
)
)
# If we can find sql= and parameters= arguments, show those
variables = _find_variables(e.__traceback__, ["sql", "parameters"])
if "sql" in variables and "parameters" in variables:
raise click.ClickException(
"{}\n\nsql = {}\nparameters = {}".format(
str(e), variables["sql"], variables["parameters"]
)
)
else:
raise
if tracker is not None:
db[table].transform(types=tracker.types)
# Clean up open file-like objects
if sniff_buffer:
sniff_buffer.close()
if decoded_buffer:
decoded_buffer.close()
def _find_variables(tb, vars):
to_find = list(vars)
found = {}
for var in to_find:
if var in tb.tb_frame.f_locals:
vars.remove(var)
found[var] = tb.tb_frame.f_locals[var]
if vars and tb.tb_next:
found.update(_find_variables(tb.tb_next, vars))
return found
@cli.command()
@insert_upsert_options()
@click.option(
"--ignore", is_flag=True, default=False, help="Ignore records if pk already exists"
)
@click.option(
"--replace",
is_flag=True,
default=False,
help="Replace records if pk already exists",
)
@click.option(
"--truncate",
is_flag=True,
default=False,
help="Truncate table before inserting records, if table already exists",
)
def insert(
path,
table,
file,
pk,
flatten,
nl,
csv,
tsv,
empty_null,
lines,
text,
convert,
imports,
delimiter,
quotechar,
sniff,
no_headers,
encoding,
batch_size,
stop_after,
alter,
detect_types,
analyze,
load_extension,
silent,
ignore,
replace,
truncate,
not_null,
default,
strict,
):
"""
Insert records from FILE into a table, creating the table if it
does not already exist.
Example:
echo '{"name": "Lila"}' | sqlite-utils insert data.db chickens -
By default the input is expected to be a JSON object or array of objects.
\b
- Use --nl for newline-delimited JSON objects
- Use --csv or --tsv for comma-separated or tab-separated input
- Use --lines to write each incoming line to a column called "line"
- Use --text to write the entire input to a column called "text"
You can also use --convert to pass a fragment of Python code that will
be used to convert each input.
Your Python code will be passed a "row" variable representing the
imported row, and can return a modified row.
This example uses just the name, latitude and longitude columns from
a CSV file, converting name to upper case and latitude and longitude
to floating point numbers:
\b
sqlite-utils insert plants.db plants plants.csv --csv --convert '
return {
"name": row["name"].upper(),
"latitude": float(row["latitude"]),
"longitude": float(row["longitude"]),
}'
If you are using --lines your code will be passed a "line" variable,
and for --text a "text" variable.
When using --text your function can return an iterator of rows to
insert. This example inserts one record per word in the input:
\b
echo 'A bunch of words' | sqlite-utils insert words.db words - \\
--text --convert '({"word": w} for w in text.split())'
"""
try:
insert_upsert_implementation(
path,
table,
file,
pk,
flatten,
nl,
csv,
tsv,
empty_null,
lines,
text,
convert,
imports,
delimiter,
quotechar,
sniff,
no_headers,
encoding,
batch_size,
stop_after,
alter=alter,
upsert=False,
ignore=ignore,
replace=replace,
truncate=truncate,
detect_types=detect_types,
analyze=analyze,
load_extension=load_extension,
silent=silent,
not_null=not_null,
default=default,
strict=strict,
)
except UnicodeDecodeError as ex:
raise click.ClickException(UNICODE_ERROR.format(ex))
@cli.command()
@insert_upsert_options(require_pk=True)
def upsert(
path,
table,
file,
pk,
flatten,
nl,
csv,
tsv,
empty_null,
lines,
text,
convert,
imports,
batch_size,
stop_after,
delimiter,
quotechar,
sniff,
no_headers,
encoding,
alter,
not_null,
default,
detect_types,
analyze,
load_extension,
silent,
strict,
):
"""
Upsert records based on their primary key. Works like 'insert' but if
an incoming record has a primary key that matches an existing record
the existing record will be updated.
Example:
\b
echo '[
{"id": 1, "name": "Lila"},
{"id": 2, "name": "Suna"}
]' | sqlite-utils upsert data.db chickens - --pk id
"""
try:
insert_upsert_implementation(
path,
table,
file,
pk,
flatten,
nl,
csv,
tsv,
empty_null,
lines,
text,
convert,
imports,
delimiter,
quotechar,
sniff,
no_headers,
encoding,
batch_size,
stop_after,
alter=alter,
upsert=True,
not_null=not_null,
default=default,
detect_types=detect_types,
analyze=analyze,
load_extension=load_extension,
silent=silent,
strict=strict,
)
except UnicodeDecodeError as ex:
raise click.ClickException(UNICODE_ERROR.format(ex))
@cli.command()
@click.argument(
"path",
type=click.Path(file_okay=True, dir_okay=False, allow_dash=False),
required=True,
)
@click.argument("sql")
@click.argument("file", type=click.File("rb"), required=True)
@click.option("--batch-size", type=int, default=100, help="Commit every X records")
@click.option(
"--functions", help="Python code defining one or more custom SQL functions"
)
@import_options
@load_extension_option
def bulk(
path,
sql,
file,
batch_size,
functions,
flatten,
nl,
csv,
tsv,
empty_null,
lines,
text,
convert,
imports,
delimiter,
quotechar,
sniff,
no_headers,
encoding,
load_extension,
):
"""
Execute parameterized SQL against the provided list of documents.
Example:
\b
echo '[
{"id": 1, "name": "Lila2"},
{"id": 2, "name": "Suna2"}
]' | sqlite-utils bulk data.db '
update chickens set name = :name where id = :id
' -
"""
try:
insert_upsert_implementation(
path=path,
table=None,
file=file,
pk=None,
flatten=flatten,
nl=nl,
csv=csv,
tsv=tsv,
empty_null=empty_null,
lines=lines,
text=text,
convert=convert,
imports=imports,
delimiter=delimiter,
quotechar=quotechar,
sniff=sniff,
no_headers=no_headers,
encoding=encoding,
batch_size=batch_size,
stop_after=None,
alter=False,
upsert=False,
not_null=set(),
default={},
detect_types=False,
load_extension=load_extension,
silent=False,
bulk_sql=sql,
functions=functions,
)
except (OperationalError, sqlite3.IntegrityError) as e:
raise click.ClickException(str(e))
@cli.command(name="create-database")
@click.argument(
"path",
type=click.Path(file_okay=True, dir_okay=False, allow_dash=False),
required=True,
)
@click.option(
"--enable-wal", is_flag=True, help="Enable WAL mode on the created database"
)
@click.option(
"--init-spatialite", is_flag=True, help="Enable SpatiaLite on the created database"
)
@load_extension_option
def create_database(path, enable_wal, init_spatialite, load_extension):
"""Create a new empty database file
Example:
\b
sqlite-utils create-database trees.db
"""
db = sqlite_utils.Database(path)
if enable_wal:
db.enable_wal()
# load spatialite or another extension from a custom location
if load_extension:
_load_extensions(db, load_extension)
# load spatialite from expected locations and initialize metadata
if init_spatialite:
db.init_spatialite()
db.vacuum()
@cli.command(name="create-table")
@click.argument(
"path",
type=click.Path(file_okay=True, dir_okay=False, allow_dash=False),
required=True,
)
@click.argument("table")
@click.argument("columns", nargs=-1, required=True)
@click.option("pks", "--pk", help="Column to use as primary key", multiple=True)
@click.option(
"--not-null",
multiple=True,
help="Columns that should be created as NOT NULL",
)
@click.option(
"--default",
multiple=True,
type=(str, str),
help="Default value that should be set for a column",
)
@click.option(
"--fk",
multiple=True,
type=(str, str, str),
help="Column, other table, other column to set as a foreign key",
)
@click.option(
"--ignore",
is_flag=True,
help="If table already exists, do nothing",
)
@click.option(
"--replace",
is_flag=True,
help="If table already exists, replace it",
)
@click.option(
"--transform",
is_flag=True,
help="If table already exists, try to transform the schema",
)
@load_extension_option
@click.option(
"--strict",
is_flag=True,
help="Apply STRICT mode to created table",
)
def create_table(
path,
table,
columns,
pks,
not_null,
default,
fk,
ignore,
replace,
transform,
load_extension,
strict,
):
"""
Add a table with the specified columns. Columns should be specified using
name, type pairs, for example:
\b
sqlite-utils create-table my.db people \\
id integer \\
name text \\
height float \\
photo blob --pk id
Valid column types are text, integer, float and blob.
"""
db = sqlite_utils.Database(path)
_load_extensions(db, load_extension)
if len(columns) % 2 == 1:
raise click.ClickException(
"columns must be an even number of 'name' 'type' pairs"
)
coltypes = {}
columns = list(columns)
while columns:
name = columns.pop(0)
ctype = columns.pop(0)
if ctype.upper() not in VALID_COLUMN_TYPES:
raise click.ClickException(
"column types must be one of {}".format(VALID_COLUMN_TYPES)
)
coltypes[name] = ctype.upper()
# Does table already exist?
if table in db.table_names():
if not ignore and not replace and not transform:
raise click.ClickException(
'Table "{}" already exists. Use --replace to delete and replace it.'.format(
table
)
)
db[table].create(
coltypes,
pk=pks[0] if len(pks) == 1 else pks,
not_null=not_null,
defaults=dict(default),
foreign_keys=fk,
ignore=ignore,
replace=replace,
transform=transform,
strict=strict,
)
@cli.command(name="duplicate")
@click.argument(
"path",
type=click.Path(file_okay=True, dir_okay=False, allow_dash=False),
required=True,
)
@click.argument("table")
@click.argument("new_table")
@click.option("--ignore", is_flag=True, help="If table does not exist, do nothing")
@load_extension_option
def duplicate(path, table, new_table, ignore, load_extension):
"""
Create a duplicate of this table, copying across the schema and all row data.
"""
db = sqlite_utils.Database(path)
_load_extensions(db, load_extension)
try:
db[table].duplicate(new_table)
except NoTable:
if not ignore:
raise click.ClickException('Table "{}" does not exist'.format(table))
@cli.command(name="rename-table")
@click.argument(
"path",
type=click.Path(file_okay=True, dir_okay=False, allow_dash=False),
required=True,
)
@click.argument("table")
@click.argument("new_name")
@click.option("--ignore", is_flag=True, help="If table does not exist, do nothing")
@load_extension_option
def rename_table(path, table, new_name, ignore, load_extension):
"""
Rename this table.
"""
db = sqlite_utils.Database(path)
_load_extensions(db, load_extension)
try:
db.rename_table(table, new_name)
except sqlite3.OperationalError as ex:
if not ignore:
raise click.ClickException(
'Table "{}" could not be renamed. {}'.format(table, str(ex))
)
@cli.command(name="drop-table")
@click.argument(
"path",
type=click.Path(file_okay=True, dir_okay=False, allow_dash=False),
required=True,
)
@click.argument("table")
@click.option("--ignore", is_flag=True, help="If table does not exist, do nothing")
@load_extension_option
def drop_table(path, table, ignore, load_extension):
"""Drop the specified table
Example:
\b
sqlite-utils drop-table chickens.db chickens
"""
db = sqlite_utils.Database(path)
_load_extensions(db, load_extension)
try:
db[table].drop(ignore=ignore)
except OperationalError:
raise click.ClickException('Table "{}" does not exist'.format(table))
@cli.command(name="create-view")
@click.argument(
"path",
type=click.Path(file_okay=True, dir_okay=False, allow_dash=False),
required=True,
)
@click.argument("view")
@click.argument("select")
@click.option(
"--ignore",
is_flag=True,
help="If view already exists, do nothing",
)
@click.option(
"--replace",
is_flag=True,
help="If view already exists, replace it",
)
@load_extension_option
def create_view(path, view, select, ignore, replace, load_extension):
"""Create a view for the provided SELECT query
Example:
\b
sqlite-utils create-view chickens.db heavy_chickens \\
'select * from chickens where weight > 3'
"""
db = sqlite_utils.Database(path)
_load_extensions(db, load_extension)
# Does view already exist?
if view in db.view_names():
if ignore:
return
elif replace:
db[view].drop()
else:
raise click.ClickException(
'View "{}" already exists. Use --replace to delete and replace it.'.format(
view
)
)
db.create_view(view, select)
@cli.command(name="drop-view")
@click.argument(
"path",
type=click.Path(file_okay=True, dir_okay=False, allow_dash=False),
required=True,
)
@click.argument("view")
@click.option("--ignore", is_flag=True, help="If view does not exist, do nothing")
@load_extension_option
def drop_view(path, view, ignore, load_extension):
"""Drop the specified view
Example:
\b
sqlite-utils drop-view chickens.db heavy_chickens
"""
db = sqlite_utils.Database(path)
_load_extensions(db, load_extension)
try:
db[view].drop(ignore=ignore)
except OperationalError:
raise click.ClickException('View "{}" does not exist'.format(view))
@cli.command()
@click.argument(
"path",
type=click.Path(file_okay=True, dir_okay=False, allow_dash=False),
required=True,
)
@click.argument("sql")
@click.option(
"--attach",
type=(str, click.Path(file_okay=True, dir_okay=False, allow_dash=False)),
multiple=True,
help="Additional databases to attach - specify alias and filepath",
)
@output_options
@click.option("-r", "--raw", is_flag=True, help="Raw output, first column of first row")
@click.option("--raw-lines", is_flag=True, help="Raw output, first column of each row")
@click.option(
"-p",
"--param",
multiple=True,
type=(str, str),
help="Named :parameters for SQL query",
)
@click.option(
"--functions", help="Python code defining one or more custom SQL functions"
)
@load_extension_option
def query(
path,
sql,
attach,
nl,
arrays,
csv,
tsv,
no_headers,
table,
fmt,
json_cols,
raw,
raw_lines,
param,
load_extension,
functions,
):
"""Execute SQL query and return the results as JSON
Example:
\b
sqlite-utils data.db \\
"select * from chickens where age > :age" \\
-p age 1
"""
db = sqlite_utils.Database(path)
for alias, attach_path in attach:
db.attach(alias, attach_path)
_load_extensions(db, load_extension)
db.register_fts4_bm25()
if functions:
_register_functions(db, functions)
_execute_query(
db,
sql,
param,
raw,
raw_lines,
table,
csv,
tsv,
no_headers,
fmt,
nl,
arrays,
json_cols,
)
@cli.command()
@click.argument(
"paths",
type=click.Path(file_okay=True, dir_okay=False, allow_dash=True),
required=False,
nargs=-1,
)
@click.argument("sql")
@click.option(
"--functions", help="Python code defining one or more custom SQL functions"
)
@click.option(
"--attach",
type=(str, click.Path(file_okay=True, dir_okay=False, allow_dash=False)),
multiple=True,
help="Additional databases to attach - specify alias and filepath",
)
@click.option(
"--flatten",
is_flag=True,
help='Flatten nested JSON objects, so {"foo": {"bar": 1}} becomes {"foo_bar": 1}',
)
@output_options
@click.option("-r", "--raw", is_flag=True, help="Raw output, first column of first row")
@click.option("--raw-lines", is_flag=True, help="Raw output, first column of each row")
@click.option(
"-p",
"--param",
multiple=True,
type=(str, str),
help="Named :parameters for SQL query",
)
@click.option(
"--encoding",
help="Character encoding for CSV input, defaults to utf-8",
)
@click.option(
"-n",
"--no-detect-types",
is_flag=True,
help="Treat all CSV/TSV columns as TEXT",
)
@click.option("--schema", is_flag=True, help="Show SQL schema for in-memory database")
@click.option("--dump", is_flag=True, help="Dump SQL for in-memory database")
@click.option(
"--save",
type=click.Path(file_okay=True, dir_okay=False, allow_dash=False),
help="Save in-memory database to this file",
)
@click.option(
"--analyze",
is_flag=True,
help="Analyze resulting tables and output results",
)
@load_extension_option
def memory(
paths,
sql,
functions,
attach,
flatten,
nl,
arrays,
csv,
tsv,
no_headers,
table,
fmt,
json_cols,
raw,
raw_lines,
param,
encoding,
no_detect_types,
schema,
dump,
save,
analyze,
load_extension,
return_db=False,
):
"""Execute SQL query against an in-memory database, optionally populated by imported data
To import data from CSV, TSV or JSON files pass them on the command-line:
\b
sqlite-utils memory one.csv two.json \\
"select * from one join two on one.two_id = two.id"
For data piped into the tool from standard input, use "-" or "stdin":
\b
cat animals.csv | sqlite-utils memory - \\
"select * from stdin where species = 'dog'"
The format of the data will be automatically detected. You can specify the format
explicitly using :json, :csv, :tsv or :nl (for newline-delimited JSON) - for example:
\b
cat animals.csv | sqlite-utils memory stdin:csv places.dat:nl \\
"select * from stdin where place_id in (select id from places)"
Use --schema to view the SQL schema of any imported files:
\b
sqlite-utils memory animals.csv --schema
"""
db = sqlite_utils.Database(memory=True)
# If --dump, --save, --schema or --analyze was used but no paths were detected, assume the SQL query is actually a path:
if (dump or save or schema or analyze) and not paths:
paths = [sql]
sql = None
stem_counts = {}
for i, path in enumerate(paths):
# Path may have a :format suffix
fp = None
if ":" in path and path.rsplit(":", 1)[-1].upper() in Format.__members__:
path, suffix = path.rsplit(":", 1)
format = Format[suffix.upper()]
else:
format = None
if path in ("-", "stdin"):
fp = sys.stdin.buffer
file_table = "stdin"
else:
file_path = pathlib.Path(path)
stem = file_path.stem
if stem_counts.get(stem):
file_table = "{}_{}".format(stem, stem_counts[stem])
else:
file_table = stem
stem_counts[stem] = stem_counts.get(stem, 1) + 1
fp = file_path.open("rb")
rows, format_used = rows_from_file(fp, format=format, encoding=encoding)
tracker = None
if format_used in (Format.CSV, Format.TSV) and not no_detect_types:
tracker = TypeTracker()
rows = tracker.wrap(rows)
if flatten:
rows = (_flatten(row) for row in rows)
db[file_table].insert_all(rows, alter=True)
if tracker is not None:
db[file_table].transform(types=tracker.types)
# Add convenient t / t1 / t2 views
view_names = ["t{}".format(i + 1)]
if i == 0:
view_names.append("t")
for view_name in view_names:
if not db[view_name].exists():
db.create_view(view_name, "select * from [{}]".format(file_table))
if fp:
fp.close()
if analyze:
_analyze(db, tables=None, columns=None, save=False)
return
if dump:
for line in db.iterdump():
click.echo(line)
return
if schema:
click.echo(db.schema)
return
if save:
db2 = sqlite_utils.Database(save)
for line in db.iterdump():
db2.execute(line)
return
for alias, attach_path in attach:
db.attach(alias, attach_path)
_load_extensions(db, load_extension)
db.register_fts4_bm25()
if functions:
_register_functions(db, functions)
if return_db:
return db
_execute_query(
db,
sql,
param,
raw,
raw_lines,
table,
csv,
tsv,
no_headers,
fmt,
nl,
arrays,
json_cols,
)
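# Illustrative sketch (not part of the library code): given the logic above,
# "sqlite-utils memory one.csv two.csv:tsv ..." imports one.csv as table "one"
# (also exposed as views "t" and "t1") and two.csv, parsed as TSV because of the
# ":tsv" suffix, as table "two" (view "t2"). A rough Python equivalent, assuming
# the files exist:
#
#     db = sqlite_utils.Database(memory=True)
#     with open("one.csv", "rb") as fp:
#         rows, _ = rows_from_file(fp, format=Format.CSV)
#         db["one"].insert_all(rows, alter=True)
#     db.create_view("t1", "select * from [one]")
#     db.create_view("t", "select * from [one]")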
def _execute_query(
db,
sql,
param,
raw,
raw_lines,
table,
csv,
tsv,
no_headers,
fmt,
nl,
arrays,
json_cols,
):
with db.conn:
try:
cursor = db.execute(sql, dict(param))
except OperationalError as e:
raise click.ClickException(str(e))
if cursor.description is None:
# This was an update/insert
headers = ["rows_affected"]
cursor = [[cursor.rowcount]]
else:
headers = [c[0] for c in cursor.description]
if raw:
data = cursor.fetchone()[0]
if isinstance(data, bytes):
sys.stdout.buffer.write(data)
else:
sys.stdout.write(str(data))
elif raw_lines:
for row in cursor:
data = row[0]
if isinstance(data, bytes):
sys.stdout.buffer.write(data + b"\n")
else:
sys.stdout.write(str(data) + "\n")
elif fmt or table:
print(
tabulate.tabulate(
list(cursor), headers=headers, tablefmt=fmt or "simple"
)
)
elif csv or tsv:
writer = csv_std.writer(sys.stdout, dialect="excel-tab" if tsv else "excel")
if not no_headers:
writer.writerow(headers)
for row in cursor:
writer.writerow(row)
else:
for line in output_rows(cursor, headers, nl, arrays, json_cols):
click.echo(line)
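# Output mode precedence in _execute_query, summarized for reference:
#   --raw         -> first column of the first row, written directly to stdout
#   --raw-lines   -> first column of every row, one value per line
#   --fmt/--table -> tabulate-rendered table (tablefmt defaults to "simple")
#   --csv/--tsv   -> csv module writer, headers included unless --no-headers
#   otherwise     -> JSON, streamed via output_rows()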
@cli.command()
@click.argument(
"path",
type=click.Path(file_okay=True, dir_okay=False, allow_dash=False),
required=True,
)
@click.argument("dbtable")
@click.argument("q")
@click.option("-o", "--order", type=str, help="Order by ('column' or 'column desc')")
@click.option("-c", "--column", type=str, multiple=True, help="Columns to return")
@click.option(
"--limit",
type=int,
help="Number of rows to return - defaults to everything",
)
@click.option(
"--sql", "show_sql", is_flag=True, help="Show SQL query that would be run"
)
@click.option("--quote", is_flag=True, help="Apply FTS quoting rules to search term")
@output_options
@load_extension_option
@click.pass_context
def search(
ctx,
path,
dbtable,
q,
order,
show_sql,
quote,
column,
limit,
nl,
arrays,
csv,
tsv,
no_headers,
table,
fmt,
json_cols,
load_extension,
):
"""Execute a full-text search against this table
Example:
sqlite-utils search data.db chickens lila
"""
db = sqlite_utils.Database(path)
_load_extensions(db, load_extension)
# Check table exists
table_obj = db[dbtable]
if not table_obj.exists():
raise click.ClickException("Table '{}' does not exist".format(dbtable))
if not table_obj.detect_fts():
raise click.ClickException(
"Table '{}' is not configured for full-text search".format(dbtable)
)
if column:
# Check they all exist
table_columns = table_obj.columns_dict
for c in column:
if c not in table_columns:
raise click.ClickException(
"Table '{}' has no column '{}".format(dbtable, c)
)
sql = table_obj.search_sql(columns=column, order_by=order, limit=limit)
if show_sql:
click.echo(sql)
return
if quote:
q = db.quote_fts(q)
try:
ctx.invoke(
query,
path=path,
sql=sql,
nl=nl,
arrays=arrays,
csv=csv,
tsv=tsv,
no_headers=no_headers,
table=table,
fmt=fmt,
json_cols=json_cols,
param=[("query", q)],
load_extension=load_extension,
)
except click.ClickException as e:
if "malformed MATCH expression" in str(e) or "unterminated string" in str(e):
raise click.ClickException(
"{}\n\nTry running this again with the --quote option".format(str(e))
)
else:
raise
@cli.command()
@click.argument(
"path",
type=click.Path(file_okay=True, dir_okay=False, allow_dash=False),
required=True,
)
@click.argument("dbtable")
@click.option("-c", "--column", type=str, multiple=True, help="Columns to return")
@click.option("--where", help="Optional where clause")
@click.option("-o", "--order", type=str, help="Order by ('column' or 'column desc')")
@click.option(
"-p",
"--param",
multiple=True,
type=(str, str),
help="Named :parameters for where clause",
)
@click.option(
"--limit",
type=int,
help="Number of rows to return - defaults to everything",
)
@click.option(
"--offset",
type=int,
help="SQL offset to use",
)
@output_options
@load_extension_option
@click.pass_context
def rows(
ctx,
path,
dbtable,
column,
where,
order,
param,
limit,
offset,
nl,
arrays,
csv,
tsv,
no_headers,
table,
fmt,
json_cols,
load_extension,
):
"""Output all rows in the specified table
Example:
\b
sqlite-utils rows trees.db Trees
"""
columns = "*"
if column:
columns = ", ".join("[{}]".format(c) for c in column)
sql = "select {} from [{}]".format(columns, dbtable)
if where:
sql += " where " + where
if order:
sql += " order by " + order
if limit:
sql += " limit {}".format(limit)
if offset:
sql += " offset {}".format(offset)
ctx.invoke(
query,
path=path,
sql=sql,
nl=nl,
arrays=arrays,
csv=csv,
tsv=tsv,
no_headers=no_headers,
table=table,
fmt=fmt,
param=param,
json_cols=json_cols,
load_extension=load_extension,
)
@cli.command()
@click.argument(
"path",
type=click.Path(file_okay=True, dir_okay=False, allow_dash=False),
required=True,
)
@click.argument("tables", nargs=-1)
@output_options
@load_extension_option
@click.pass_context
def triggers(
ctx,
path,
tables,
nl,
arrays,
csv,
tsv,
no_headers,
table,
fmt,
json_cols,
load_extension,
):
"""Show triggers configured in this database
Example:
\b
sqlite-utils triggers trees.db
"""
sql = "select name, tbl_name as [table], sql from sqlite_master where type = 'trigger'"
if tables:
quote = sqlite_utils.Database(memory=True).quote
sql += " and [table] in ({})".format(
", ".join(quote(table) for table in tables)
)
ctx.invoke(
query,
path=path,
sql=sql,
nl=nl,
arrays=arrays,
csv=csv,
tsv=tsv,
no_headers=no_headers,
table=table,
fmt=fmt,
json_cols=json_cols,
load_extension=load_extension,
)
@cli.command()
@click.argument(
"path",
type=click.Path(file_okay=True, dir_okay=False, allow_dash=False),
required=True,
)
@click.argument("tables", nargs=-1)
@click.option("--aux", is_flag=True, help="Include auxiliary columns")
@output_options
@load_extension_option
@click.pass_context
def indexes(
ctx,
path,
tables,
aux,
nl,
arrays,
csv,
tsv,
no_headers,
table,
fmt,
json_cols,
load_extension,
):
"""Show indexes for the whole database or specific tables
Example:
\b
sqlite-utils indexes trees.db Trees
"""
sql = """
select
sqlite_master.name as "table",
indexes.name as index_name,
xinfo.*
from sqlite_master
join pragma_index_list(sqlite_master.name) indexes
join pragma_index_xinfo(index_name) xinfo
where
sqlite_master.type = 'table'
"""
if tables:
quote = sqlite_utils.Database(memory=True).quote
sql += " and sqlite_master.name in ({})".format(
", ".join(quote(table) for table in tables)
)
if not aux:
sql += " and xinfo.key = 1"
ctx.invoke(
query,
path=path,
sql=sql,
nl=nl,
arrays=arrays,
csv=csv,
tsv=tsv,
no_headers=no_headers,
table=table,
fmt=fmt,
json_cols=json_cols,
load_extension=load_extension,
)
@cli.command()
@click.argument(
"path",
type=click.Path(file_okay=True, dir_okay=False, allow_dash=False),
required=True,
)
@click.argument("tables", nargs=-1, required=False)
@load_extension_option
def schema(
path,
tables,
load_extension,
):
"""Show full schema for this database or for specified tables
Example:
\b
sqlite-utils schema trees.db
"""
db = sqlite_utils.Database(path)
_load_extensions(db, load_extension)
if tables:
for table in tables:
click.echo(db[table].schema)
else:
click.echo(db.schema)
@cli.command()
@click.argument(
"path",
type=click.Path(file_okay=True, dir_okay=False, allow_dash=False),
required=True,
)
@click.argument("table")
@click.option(
"--type",
type=(
str,
click.Choice(["INTEGER", "TEXT", "FLOAT", "BLOB"], case_sensitive=False),
),
multiple=True,
help="Change column type to INTEGER, TEXT, FLOAT or BLOB",
)
@click.option("--drop", type=str, multiple=True, help="Drop this column")
@click.option(
"--rename", type=(str, str), multiple=True, help="Rename this column to X"
)
@click.option("-o", "--column-order", type=str, multiple=True, help="Reorder columns")
@click.option("--not-null", type=str, multiple=True, help="Set this column to NOT NULL")
@click.option(
"--not-null-false", type=str, multiple=True, help="Remove NOT NULL from this column"
)
@click.option("--pk", type=str, multiple=True, help="Make this column the primary key")
@click.option(
"--pk-none", is_flag=True, help="Remove primary key (convert to rowid table)"
)
@click.option(
"--default",
type=(str, str),
multiple=True,
help="Set default value for this column",
)
@click.option(
"--default-none", type=str, multiple=True, help="Remove default from this column"
)
@click.option(
"add_foreign_keys",
"--add-foreign-key",
type=(str, str, str),
multiple=True,
help="Add a foreign key constraint from a column to another table with another column",
)
@click.option(
"drop_foreign_keys",
"--drop-foreign-key",
type=str,
multiple=True,
help="Drop foreign key constraint for this column",
)
@click.option("--sql", is_flag=True, help="Output SQL without executing it")
@load_extension_option
def transform(
path,
table,
type,
drop,
rename,
column_order,
not_null,
not_null_false,
pk,
pk_none,
default,
default_none,
add_foreign_keys,
drop_foreign_keys,
sql,
load_extension,
):
"""Transform a table beyond the capabilities of ALTER TABLE
Example:
\b
sqlite-utils transform mydb.db mytable \\
--drop column1 \\
--rename column2 column_renamed
"""
db = sqlite_utils.Database(path)
_load_extensions(db, load_extension)
types = {}
kwargs = {}
for column, ctype in type:
if ctype.upper() not in VALID_COLUMN_TYPES:
raise click.ClickException(
"column types must be one of {}".format(VALID_COLUMN_TYPES)
)
types[column] = ctype.upper()
not_null_dict = {}
for column in not_null:
not_null_dict[column] = True
for column in not_null_false:
not_null_dict[column] = False
default_dict = {}
for column, value in default:
default_dict[column] = value
for column in default_none:
default_dict[column] = None
kwargs["types"] = types
kwargs["drop"] = set(drop)
kwargs["rename"] = dict(rename)
kwargs["column_order"] = column_order or None
kwargs["not_null"] = not_null_dict
if pk:
if len(pk) == 1:
kwargs["pk"] = pk[0]
else:
kwargs["pk"] = pk
elif pk_none:
kwargs["pk"] = None
kwargs["defaults"] = default_dict
if drop_foreign_keys:
kwargs["drop_foreign_keys"] = drop_foreign_keys
if add_foreign_keys:
kwargs["add_foreign_keys"] = add_foreign_keys
if sql:
for line in db[table].transform_sql(**kwargs):
click.echo(line)
else:
db[table].transform(**kwargs)
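# Example (not executed): "sqlite-utils transform mydb.db mytable --drop column1
# --rename column2 column_renamed" builds kwargs roughly equal to:
#     {"types": {}, "drop": {"column1"}, "rename": {"column2": "column_renamed"},
#      "column_order": None, "not_null": {}, "defaults": {}}
# and then calls db["mytable"].transform(**kwargs).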
@cli.command()
@click.argument(
"path",
type=click.Path(file_okay=True, dir_okay=False, allow_dash=False),
required=True,
)
@click.argument("table")
@click.argument("columns", nargs=-1, required=True)
@click.option(
"--table", "other_table", help="Name of the other table to extract columns to"
)
@click.option("--fk-column", help="Name of the foreign key column to add to the table")
@click.option(
"--rename",
type=(str, str),
multiple=True,
help="Rename this column in extracted table",
)
@load_extension_option
def extract(
path,
table,
columns,
other_table,
fk_column,
rename,
load_extension,
):
"""Extract one or more columns into a separate table
Example:
\b
sqlite-utils extract trees.db Street_Trees species
"""
db = sqlite_utils.Database(path)
_load_extensions(db, load_extension)
kwargs = dict(
columns=columns,
table=other_table,
fk_column=fk_column,
rename=dict(rename),
)
db[table].extract(**kwargs)
@cli.command(name="insert-files")
@click.argument(
"path",
type=click.Path(file_okay=True, dir_okay=False, allow_dash=False),
required=True,
)
@click.argument("table")
@click.argument(
"file_or_dir",
nargs=-1,
required=True,
type=click.Path(file_okay=True, dir_okay=True, allow_dash=True),
)
@click.option(
"-c",
"--column",
type=str,
multiple=True,
help="Column definitions for the table",
)
@click.option("pks", "--pk", help="Column to use as primary key", multiple=True)
@click.option("--alter", is_flag=True, help="Alter table to add missing columns")
@click.option("--replace", is_flag=True, help="Replace files with matching primary key")
@click.option("--upsert", is_flag=True, help="Upsert files with matching primary key")
@click.option("--name", type=str, help="File name to use")
@click.option("--text", is_flag=True, help="Store file content as TEXT, not BLOB")
@click.option(
"--encoding",
help="Character encoding for input, defaults to utf-8",
)
@click.option("-s", "--silent", is_flag=True, help="Don't show a progress bar")
@load_extension_option
def insert_files(
path,
table,
file_or_dir,
column,
pks,
alter,
replace,
upsert,
name,
text,
encoding,
silent,
load_extension,
):
"""
Insert one or more files using BLOB columns in the specified table
Example:
\b
sqlite-utils insert-files pics.db images *.gif \\
-c name:name \\
-c content:content \\
-c content_hash:sha256 \\
-c created:ctime_iso \\
-c modified:mtime_iso \\
-c size:size \\
--pk name
"""
if not column:
if text:
column = ["path:path", "content_text:content_text", "size:size"]
else:
column = ["path:path", "content:content", "size:size"]
if not pks:
pks = ["path"]
def yield_paths_and_relative_paths():
for f_or_d in file_or_dir:
path = pathlib.Path(f_or_d)
if f_or_d == "-":
yield "-", "-"
elif path.is_dir():
for subpath in path.rglob("*"):
if subpath.is_file():
yield subpath, subpath.relative_to(path)
elif path.is_file():
yield path, path
# Load all paths so we can show a progress bar
paths_and_relative_paths = list(yield_paths_and_relative_paths())
with progressbar(paths_and_relative_paths, silent=silent) as bar:
def to_insert():
for path, relative_path in bar:
row = {}
# content_text is a special case because it takes 'encoding' into account
def _content_text(p):
resolved = p.resolve()
try:
return resolved.read_text(encoding=encoding)
except UnicodeDecodeError as e:
raise UnicodeDecodeErrorForPath(e, resolved)
lookups = dict(FILE_COLUMNS, content_text=_content_text)
if path == "-":
stdin_data = sys.stdin.buffer.read()
# We only support a subset of columns for this case
lookups = {
"name": lambda p: name or "-",
"path": lambda p: name or "-",
"content": lambda p: stdin_data,
"content_text": lambda p: stdin_data.decode(
encoding or "utf-8"
),
"sha256": lambda p: hashlib.sha256(stdin_data).hexdigest(),
"md5": lambda p: hashlib.md5(stdin_data).hexdigest(),
"size": lambda p: len(stdin_data),
}
for coldef in column:
if ":" in coldef:
colname, coltype = coldef.rsplit(":", 1)
else:
colname, coltype = coldef, coldef
try:
value = lookups[coltype](path)
row[colname] = value
except KeyError:
raise click.ClickException(
"'{}' is not a valid column definition - options are {}".format(
coltype, ", ".join(lookups.keys())
)
)
# Special case for --name
if coltype == "name" and name:
row[colname] = name
yield row
db = sqlite_utils.Database(path)
_load_extensions(db, load_extension)
try:
with db.conn:
db[table].insert_all(
to_insert(),
pk=pks[0] if len(pks) == 1 else pks,
alter=alter,
replace=replace,
upsert=upsert,
)
except UnicodeDecodeErrorForPath as e:
raise click.ClickException(
UNICODE_ERROR.format(
"Could not read file '{}' as text\n\n{}".format(e.path, e.exception)
)
)
@cli.command(name="analyze-tables")
@click.argument(
"path",
type=click.Path(file_okay=True, dir_okay=False, allow_dash=False, exists=True),
required=True,
)
@click.argument("tables", nargs=-1)
@click.option(
"-c",
"--column",
"columns",
type=str,
multiple=True,
help="Specific columns to analyze",
)
@click.option("--save", is_flag=True, help="Save results to _analyze_tables table")
@click.option("--common-limit", type=int, default=10, help="How many common values")
@click.option("--no-most", is_flag=True, default=False, help="Skip most common values")
@click.option(
"--no-least", is_flag=True, default=False, help="Skip least common values"
)
@load_extension_option
def analyze_tables(
path,
tables,
columns,
save,
common_limit,
no_most,
no_least,
load_extension,
):
"""Analyze the columns in one or more tables
Example:
\b
sqlite-utils analyze-tables data.db trees
"""
db = sqlite_utils.Database(path)
_load_extensions(db, load_extension)
_analyze(db, tables, columns, save, common_limit, no_most, no_least)
def _analyze(db, tables, columns, save, common_limit=10, no_most=False, no_least=False):
if not tables:
tables = db.table_names()
todo = []
table_counts = {}
seen_columns = set()
for table in tables:
table_counts[table] = db[table].count
for column in db[table].columns:
if not columns or column.name in columns:
todo.append((table, column.name))
seen_columns.add(column.name)
# Check the user didn't specify a column that doesn't exist
if columns and (set(columns) - seen_columns):
raise click.ClickException(
"These columns were not found: {}".format(
", ".join(sorted(set(columns) - seen_columns))
)
)
# Now we know how many we need to do
for i, (table, column) in enumerate(todo):
column_details = db[table].analyze_column(
column,
common_limit=common_limit,
total_rows=table_counts[table],
value_truncate=80,
most_common=not no_most,
least_common=not no_least,
)
if save:
db["_analyze_tables_"].insert(
column_details._asdict(), pk=("table", "column"), replace=True
)
most_common_rendered = ""
if column_details.num_null != column_details.total_rows:
most_common_rendered = _render_common(
"\n\n Most common:", column_details.most_common
)
least_common_rendered = _render_common(
"\n\n Least common:", column_details.least_common
)
details = (
(
textwrap.dedent(
"""
{table}.{column}: ({i}/{total})
Total rows: {total_rows}
Null rows: {num_null}
Blank rows: {num_blank}
Distinct values: {num_distinct}{most_common_rendered}{least_common_rendered}
"""
)
.strip()
.format(
i=i + 1,
total=len(todo),
most_common_rendered=most_common_rendered,
least_common_rendered=least_common_rendered,
**column_details._asdict(),
)
)
+ "\n"
)
click.echo(details)
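# Illustrative output shape for one analyzed column (numbers are hypothetical):
#
#     trees.species: (1/1)
#       Total rows: 8920
#       Null rows: 0
#       Blank rows: 0
#       Distinct values: 311
#
#       Most common:
#         450: Oak
#         210: Pine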
@cli.command()
@click.argument("packages", nargs=-1, required=False)
@click.option(
"-U", "--upgrade", is_flag=True, help="Upgrade packages to latest version"
)
@click.option(
"-e",
"--editable",
help="Install a project in editable mode from this path",
)
def install(packages, upgrade, editable):
"""Install packages from PyPI into the same environment as sqlite-utils"""
args = ["pip", "install"]
if upgrade:
args += ["--upgrade"]
if editable:
args += ["--editable", editable]
args += list(packages)
sys.argv = args
run_module("pip", run_name="__main__")
@cli.command()
@click.argument("packages", nargs=-1, required=True)
@click.option("-y", "--yes", is_flag=True, help="Don't ask for confirmation")
def uninstall(packages, yes):
"""Uninstall Python packages from the sqlite-utils environment"""
sys.argv = ["pip", "uninstall"] + list(packages) + (["-y"] if yes else [])
run_module("pip", run_name="__main__")
def _generate_convert_help():
help = textwrap.dedent(
"""
Convert columns using Python code you supply. For example:
\b
sqlite-utils convert my.db mytable mycolumn \\
'"\\n".join(textwrap.wrap(value, 10))' \\
--import=textwrap
"value" is a variable with the column value to be converted.
Use "-" for CODE to read Python code from standard input.
The following common operations are available as recipe functions:
"""
).strip()
recipe_names = [
n
for n in dir(recipes)
if not n.startswith("_")
and n not in ("json", "parser")
and callable(getattr(recipes, n))
]
for name in recipe_names:
fn = getattr(recipes, name)
help += "\n\nr.{}{}\n\n\b{}".format(
name, str(inspect.signature(fn)), textwrap.dedent(fn.__doc__.rstrip())
)
help += "\n\n"
help += textwrap.dedent(
"""
You can use these recipes like so:
\b
sqlite-utils convert my.db mytable mycolumn \\
'r.jsonsplit(value, delimiter=":")'
"""
).strip()
return help
@cli.command(help=_generate_convert_help())
@click.argument(
"db_path",
type=click.Path(file_okay=True, dir_okay=False, allow_dash=False),
required=True,
)
@click.argument("table", type=str)
@click.argument("columns", type=str, nargs=-1, required=True)
@click.argument("code", type=str)
@click.option(
"--import", "imports", type=str, multiple=True, help="Python modules to import"
)
@click.option(
"--dry-run", is_flag=True, help="Show results of running this against first 10 rows"
)
@click.option(
"--multi", is_flag=True, help="Populate columns for keys in returned dictionary"
)
@click.option("--where", help="Optional where clause")
@click.option(
"-p",
"--param",
multiple=True,
type=(str, str),
help="Named :parameters for where clause",
)
@click.option("--output", help="Optional separate column to populate with the output")
@click.option(
"--output-type",
help="Column type to use for the output column",
default="text",
type=click.Choice(["integer", "float", "blob", "text"]),
)
@click.option("--drop", is_flag=True, help="Drop original column afterwards")
@click.option("--no-skip-false", is_flag=True, help="Don't skip falsey values")
@click.option("-s", "--silent", is_flag=True, help="Don't show a progress bar")
@click.option("pdb_", "--pdb", is_flag=True, help="Open pdb debugger on first error")
def convert(
db_path,
table,
columns,
code,
imports,
dry_run,
multi,
where,
param,
output,
output_type,
drop,
no_skip_false,
silent,
pdb_,
):
sqlite3.enable_callback_tracebacks(True)
db = sqlite_utils.Database(db_path)
if output is not None and len(columns) > 1:
raise click.ClickException("Cannot use --output with more than one column")
if multi and len(columns) > 1:
raise click.ClickException("Cannot use --multi with more than one column")
if drop and not (output or multi):
raise click.ClickException("--drop can only be used with --output or --multi")
if code == "-":
# Read code from standard input
code = sys.stdin.read()
where_args = dict(param) if param else []
# Compile the code into a function body called fn(value)
try:
fn = _compile_code(code, imports)
except SyntaxError as e:
raise click.ClickException(str(e))
if dry_run:
# Pull first 10 values for the first column and preview them
if multi:
def preview(v):
return json.dumps(fn(v), default=repr) if v else v
else:
def preview(v):
return fn(v) if v else v
db.conn.create_function("preview_transform", 1, preview)
sql = """
select
[{column}] as value,
preview_transform([{column}]) as preview
from [{table}]{where} limit 10
""".format(
column=columns[0],
table=table,
where=" where {}".format(where) if where is not None else "",
)
for row in db.conn.execute(sql, where_args).fetchall():
click.echo(str(row[0]))
click.echo(" --- becomes:")
click.echo(str(row[1]))
click.echo()
count = db[table].count_where(
where=where,
where_args=where_args,
)
click.echo("Would affect {} row{}".format(count, "" if count == 1 else "s"))
else:
# Wrap fn so that errors are caught and we optionally drop into pdb
if pdb_:
fn_ = fn
def wrapped_fn(value):
try:
return fn_(value)
except Exception as ex:
print("\nException raised, dropping into pdb...:", ex)
pdb.post_mortem(ex.__traceback__)
sys.exit(1)
fn = wrapped_fn
try:
db[table].convert(
columns,
fn,
where=where,
where_args=where_args,
output=output,
output_type=output_type,
drop=drop,
skip_false=not no_skip_false,
multi=multi,
show_progress=not silent,
)
except BadMultiValues as e:
raise click.ClickException(
"When using --multi code must return a Python dictionary - returned: {}".format(
repr(e.values)
)
)
@cli.command("add-geometry-column")
@click.argument(
"db_path",
type=click.Path(file_okay=True, dir_okay=False, allow_dash=False),
required=True,
)
@click.argument("table", type=str)
@click.argument("column_name", type=str)
@click.option(
"-t",
"--type",
"geometry_type",
type=click.Choice(
[
"POINT",
"LINESTRING",
"POLYGON",
"MULTIPOINT",
"MULTILINESTRING",
"MULTIPOLYGON",
"GEOMETRYCOLLECTION",
"GEOMETRY",
],
case_sensitive=False,
),
default="GEOMETRY",
help="Specify a geometry type for this column.",
show_default=True,
)
@click.option(
"--srid",
type=int,
default=4326,
show_default=True,
help="Spatial Reference ID. See https://spatialreference.org for details on specific projections.",
)
@click.option(
"--dimensions",
"coord_dimension",
type=str,
default="XY",
help="Coordinate dimensions. Use XYZ for three-dimensional geometries.",
)
@click.option("--not-null", "not_null", is_flag=True, help="Add a NOT NULL constraint.")
@load_extension_option
def add_geometry_column(
db_path,
table,
column_name,
geometry_type,
srid,
coord_dimension,
not_null,
load_extension,
):
"""Add a SpatiaLite geometry column to an existing table. Requires SpatiaLite extension.
\n\n
By default, this command will try to load the SpatiaLite extension from usual paths.
To load it from a specific path, use --load-extension."""
db = sqlite_utils.Database(db_path)
if not db[table].exists():
raise click.ClickException(
"You must create a table before adding a geometry column"
)
# load spatialite, one way or another
if load_extension:
_load_extensions(db, load_extension)
db.init_spatialite()
if db[table].add_geometry_column(
column_name, geometry_type, srid, coord_dimension, not_null
):
click.echo(f"Added {geometry_type} column {column_name} to {table}")
@cli.command("create-spatial-index")
@click.argument(
"db_path",
type=click.Path(file_okay=True, dir_okay=False, allow_dash=False),
required=True,
)
@click.argument("table", type=str)
@click.argument("column_name", type=str)
@load_extension_option
def create_spatial_index(db_path, table, column_name, load_extension):
"""Create a spatial index on a SpatiaLite geometry column.
The table and geometry column must already exist before trying to add a spatial index.
\n\n
By default, this command will try to load the SpatiaLite extension from usual paths.
To load it from a specific path, use --load-extension."""
db = sqlite_utils.Database(db_path)
if not db[table].exists():
raise click.ClickException(
"You must create a table and add a geometry column before creating a spatial index"
)
# load spatialite
if load_extension:
_load_extensions(db, load_extension)
db.init_spatialite()
if column_name not in db[table].columns_dict:
raise click.ClickException(
"You must add a geometry column before creating a spatial index"
)
db[table].create_spatial_index(column_name)
@cli.command(name="plugins")
def plugins_list():
"List installed plugins"
click.echo(json.dumps(get_plugins(), indent=2))
pm.hook.register_commands(cli=cli)
def _render_common(title, values):
if values is None:
return ""
lines = [title]
for value, count in values:
lines.append(" {}: {}".format(count, value))
return "\n".join(lines)
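# For example (illustrative):
#     _render_common("\n\n  Most common:", [("Oak", 450), ("Pine", 210)])
# returns the title followed by one "count: value" line per pair, i.e.
# "  Most common:" then "450: Oak" and "210: Pine" on separate lines.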
class UnicodeDecodeErrorForPath(Exception):
def __init__(self, exception, path):
self.exception = exception
self.path = path
FILE_COLUMNS = {
"name": lambda p: p.name,
"path": lambda p: str(p),
"fullpath": lambda p: str(p.resolve()),
"sha256": lambda p: hashlib.sha256(p.resolve().read_bytes()).hexdigest(),
"md5": lambda p: hashlib.md5(p.resolve().read_bytes()).hexdigest(),
"mode": lambda p: p.stat().st_mode,
"content": lambda p: p.resolve().read_bytes(),
"mtime": lambda p: p.stat().st_mtime,
"ctime": lambda p: p.stat().st_ctime,
"mtime_int": lambda p: int(p.stat().st_mtime),
"ctime_int": lambda p: int(p.stat().st_ctime),
"mtime_iso": lambda p: datetime.utcfromtimestamp(p.stat().st_mtime).isoformat(),
"ctime_iso": lambda p: datetime.utcfromtimestamp(p.stat().st_ctime).isoformat(),
"size": lambda p: p.stat().st_size,
"stem": lambda p: p.stem,
"suffix": lambda p: p.suffix,
}
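# Illustrative sketch: with the default insert-files column spec
# ["path:path", "content:content", "size:size"], a file at photos/cat.jpg
# (hypothetical) becomes a row like:
#     {"path": "photos/cat.jpg", "content": b"...", "size": 12345}
# built by calling FILE_COLUMNS["path"](p), FILE_COLUMNS["content"](p) and
# FILE_COLUMNS["size"](p) on the pathlib.Path.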
def output_rows(iterator, headers, nl, arrays, json_cols):
# We have to iterate two-at-a-time so we can know if we
# should output a trailing comma or if we have reached
# the last row.
current_iter, next_iter = itertools.tee(iterator, 2)
next(next_iter, None)
first = True
for row, next_row in itertools.zip_longest(current_iter, next_iter):
is_last = next_row is None
data = row
if json_cols:
# Any value that is a valid JSON string should be treated as JSON
data = [maybe_json(value) for value in data]
if not arrays:
data = dict(zip(headers, data))
line = "{firstchar}{serialized}{maybecomma}{lastchar}".format(
firstchar=("[" if first else " ") if not nl else "",
serialized=json.dumps(data, default=json_binary),
maybecomma="," if (not nl and not is_last) else "",
lastchar="]" if (is_last and not nl) else "",
)
yield line
first = False
if first:
# We didn't output any rows, so yield the empty list
yield "[]"
def maybe_json(value):
if not isinstance(value, str):
return value
stripped = value.strip()
if not (stripped.startswith("{") or stripped.startswith("[")):
return value
try:
return json.loads(stripped)
except ValueError:
return value
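# For example: maybe_json('{"a": 1}') -> {"a": 1}; maybe_json("plain") -> "plain";
# maybe_json("[1, 2") -> "[1, 2" (strings that fail to parse are returned unchanged).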
def json_binary(value):
if isinstance(value, bytes):
return {"$base64": True, "encoded": base64.b64encode(value).decode("latin-1")}
else:
raise TypeError
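# Used as json.dumps(..., default=json_binary): bytes values serialize as
# {"$base64": true, "encoded": "<base64 string>"}; any other unsupported type
# raises TypeError so json.dumps can report it.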
def verify_is_dict(doc):
if not isinstance(doc, dict):
raise click.ClickException(
"Rows must all be dictionaries, got: {}".format(repr(doc)[:1000])
)
return doc
def _load_extensions(db, load_extension):
if load_extension:
db.conn.enable_load_extension(True)
for ext in load_extension:
if ext == "spatialite" and not os.path.exists(ext):
ext = find_spatialite()
if ":" in ext:
path, _, entrypoint = ext.partition(":")
db.conn.execute("SELECT load_extension(?, ?)", [path, entrypoint])
else:
db.conn.load_extension(ext)
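# For example: --load-extension=spatialite resolves the library via find_spatialite(),
# while --load-extension=/usr/lib/mod_spatialite.so:sqlite3_modspatialite_init
# (entrypoint name illustrative) uses the "path:entrypoint" form, which is loaded
# through SELECT load_extension(?, ?).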
def _register_functions(db, functions):
# Register any Python functions as SQL functions:
sqlite3.enable_callback_tracebacks(True)
globals = {}
try:
exec(functions, globals)
except SyntaxError as ex:
raise click.ClickException("Error in functions definition: {}".format(ex))
# Register all callables in the globals dict:
for name, value in globals.items():
if callable(value) and not name.startswith("_"):
db.register_function(value, name=name)
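# Illustrative example of the --functions code this consumes (function name hypothetical):
#
#     sqlite-utils query data.db "select reverse_string(name) from t" \
#       --functions 'def reverse_string(s): return s[::-1]'
#
# Every callable defined by that code whose name does not start with "_" is
# registered via db.register_function().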
</document_content>
</document>
<document index="6">
<source>./sqlite_utils/conversions.py</source>
<document_content>
from abc import ABC, abstractmethod
from typing import Any, List
class Conversion(ABC):
@abstractmethod
def __init__(self, *args) -> None:
"Constructor should only take positional arguments"
@abstractmethod
def params(self) -> List[Any]:
"One or more strings to be used for ? in the SQL fragment"
@abstractmethod
def sql(self) -> str:
"SQL fragment to use, with ? placeholders for params"
class LatitudeLongitude(Conversion):
def __init__(self, latitude, longitude):
self.latitude = latitude
self.longitude = longitude
def params(self):
return ["POINT({}, {})".format(self.longitude, self.latitude)]
def sql(self):
return "GeomFromText(?, 4326)"
class LongitudeLatitude(LatitudeLongitude):
def __init__(self, longitude, latitude):
super().__init__(latitude, longitude)
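# A minimal usage sketch (values hypothetical), based on the library's conversions
# feature - an instance is passed as a column value and contributes its sql() and
# params() when the row is inserted:
#
#     db["places"].insert(
#         {"name": "London", "point": LongitudeLatitude(-0.118, 51.509)}
#     )
#
# which stores GeomFromText('POINT(-0.118, 51.509)', 4326) for the "point" column.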
</document_content>
</document>
<document index="7">
<source>./sqlite_utils/db.py</source>
<document_content>
from .utils import (
chunks,
hash_record,
sqlite3,
OperationalError,
suggest_column_types,
types_for_column_types,
column_affinity,
progressbar,
find_spatialite,
)
import binascii
from collections import namedtuple
from collections.abc import Mapping
import contextlib
import datetime
import decimal
import inspect
import itertools
import json
import os
import pathlib
import re
import secrets
from sqlite_fts4 import rank_bm25 # type: ignore
import textwrap
from typing import (
cast,
Any,
Callable,
Dict,
Generator,
Iterable,
Union,
Optional,
List,
Tuple,
)
import uuid
from sqlite_utils.plugins import pm
try:
from sqlite_dump import iterdump
except ImportError:
iterdump = None
SQLITE_MAX_VARS = 999
_quote_fts_re = re.compile(r'\s+|(".*?")')
_virtual_table_using_re = re.compile(
r"""
^ # Start of string
\s*CREATE\s+VIRTUAL\s+TABLE\s+ # CREATE VIRTUAL TABLE
(
'(?P<squoted_table>[^']*(?:''[^']*)*)' | # single quoted name
"(?P<dquoted_table>[^"]*(?:""[^"]*)*)" | # double quoted name
`(?P<backtick_table>[^`]+)` | # `backtick` quoted name
\[(?P<squarequoted_table>[^\]]+)\] | # [...] quoted name
(?P<identifier> # SQLite non-quoted identifier
[A-Za-z_\u0080-\uffff] # \u0080-\uffff = "any character larger than u007f"
[A-Za-z_\u0080-\uffff0-9\$]* # zero-or-more alphanumeric or $
)
)
\s+(IF\s+NOT\s+EXISTS\s+)? # IF NOT EXISTS (optional)
USING\s+(?P<using>\w+) # for example USING FTS5
""",
re.VERBOSE | re.IGNORECASE,
)
try:
import pandas as pd # type: ignore
except ImportError:
pd = None # type: ignore
try:
import numpy as np # type: ignore
except ImportError:
np = None # type: ignore
Column = namedtuple(
"Column", ("cid", "name", "type", "notnull", "default_value", "is_pk")
)
Column.__doc__ = """
Describes a SQLite column returned by the :attr:`.Table.columns` property.
``cid``
Column index
``name``
Column name
``type``
Column type
``notnull``
Does the column have a ``not null`` constraint
``default_value``
Default value for this column
``is_pk``
Is this column part of the primary key
"""
ColumnDetails = namedtuple(
"ColumnDetails",
(
"table",
"column",
"total_rows",
"num_null",
"num_blank",
"num_distinct",
"most_common",
"least_common",
),
)
ColumnDetails.__doc__ = """
Summary information about a column, see :ref:`python_api_analyze_column`.
``table``
The name of the table
``column``
The name of the column
``total_rows``
The total number of rows in the table
``num_null``
The number of rows for which this column is null
``num_blank``
The number of rows for which this column is blank (the empty string)
``num_distinct``
The number of distinct values in this column
``most_common``
The ``N`` most common values as a list of ``(value, count)`` tuples, or ``None`` if the table consists entirely of distinct values
``least_common``
The ``N`` least common values as a list of ``(value, count)`` tuples, or ``None`` if the table is entirely distinct
or if the number of distinct values is less than N (since they will already have been returned in ``most_common``)
"""
ForeignKey = namedtuple(
"ForeignKey", ("table", "column", "other_table", "other_column")
)
Index = namedtuple("Index", ("seq", "name", "unique", "origin", "partial", "columns"))
XIndex = namedtuple("XIndex", ("name", "columns"))
XIndexColumn = namedtuple(
"XIndexColumn", ("seqno", "cid", "name", "desc", "coll", "key")
)
Trigger = namedtuple("Trigger", ("name", "table", "sql"))
class TransformError(Exception):
pass
ForeignKeyIndicator = Union[
str,
ForeignKey,
Tuple[str, str],
Tuple[str, str, str],
Tuple[str, str, str, str],
]
ForeignKeysType = Union[Iterable[ForeignKeyIndicator], List[ForeignKeyIndicator]]
class Default:
pass
DEFAULT = Default()
COLUMN_TYPE_MAPPING = {
float: "FLOAT",
int: "INTEGER",
bool: "INTEGER",
str: "TEXT",
dict: "TEXT",
tuple: "TEXT",
list: "TEXT",
bytes.__class__: "BLOB",
bytes: "BLOB",
memoryview: "BLOB",
datetime.datetime: "TEXT",
datetime.date: "TEXT",
datetime.time: "TEXT",
datetime.timedelta: "TEXT",
decimal.Decimal: "FLOAT",
None.__class__: "TEXT",
uuid.UUID: "TEXT",
# SQLite explicit types
"TEXT": "TEXT",
"INTEGER": "INTEGER",
"FLOAT": "FLOAT",
"BLOB": "BLOB",
"text": "TEXT",
"str": "TEXT",
"integer": "INTEGER",
"int": "INTEGER",
"float": "FLOAT",
"blob": "BLOB",
"bytes": "BLOB",
}
# If numpy is available, add more types
if np:
try:
COLUMN_TYPE_MAPPING.update(
{
np.int8: "INTEGER",
np.int16: "INTEGER",
np.int32: "INTEGER",
np.int64: "INTEGER",
np.uint8: "INTEGER",
np.uint16: "INTEGER",
np.uint32: "INTEGER",
np.uint64: "INTEGER",
np.float16: "FLOAT",
np.float32: "FLOAT",
np.float64: "FLOAT",
}
)
except AttributeError:
# https://github.com/simonw/sqlite-utils/issues/632
pass
# If pandas is available, add more types
if pd:
COLUMN_TYPE_MAPPING.update({pd.Timestamp: "TEXT"}) # type: ignore
class AlterError(Exception):
"Error altering table"
class NoObviousTable(Exception):
"Could not tell which table this operation refers to"
class NoTable(Exception):
"Specified table does not exist"
class BadPrimaryKey(Exception):
"Table does not have a single obvious primary key"
class NotFoundError(Exception):
"Record not found"
class PrimaryKeyRequired(Exception):
"Primary key needs to be specified"
class InvalidColumns(Exception):
"Specified columns do not exist"
class DescIndex(str):
pass
class BadMultiValues(Exception):
"With multi=True code must return a Python dictionary"
def __init__(self, values):
self.values = values
_COUNTS_TABLE_CREATE_SQL = """
CREATE TABLE IF NOT EXISTS [{}](
[table] TEXT PRIMARY KEY,
count INTEGER DEFAULT 0
);
""".strip()
class Database:
"""
Wrapper for a SQLite database connection that adds a variety of useful utility methods.
To create an instance::
# create data.db file, or open existing:
db = Database("data.db")
# Create an in-memory database:
db = Database(memory=True)
:param filename_or_conn: String path to a file, or a ``pathlib.Path`` object, or a
``sqlite3`` connection
:param memory: set to ``True`` to create an in-memory database
:param memory_name: creates a named in-memory database that can be shared across multiple connections
:param recreate: set to ``True`` to delete and recreate a file database (**dangerous**)
:param recursive_triggers: defaults to ``True``, which sets ``PRAGMA recursive_triggers=on;`` -
set to ``False`` to avoid setting this pragma
:param tracer: set a tracer function (``print`` works for this) which will be called with
``sql, parameters`` every time a SQL query is executed
:param use_counts_table: set to ``True`` to use a cached counts table, if available. See
:ref:`python_api_cached_table_counts`
:param use_old_upsert: set to ``True`` to force the older upsert implementation. See
:ref:`python_api_old_upsert`
:param strict: Apply STRICT mode to all created tables (unless overridden)
"""
_counts_table_name = "_counts"
use_counts_table = False
def __init__(
self,
filename_or_conn: Optional[Union[str, pathlib.Path, sqlite3.Connection]] = None,
memory: bool = False,
memory_name: Optional[str] = None,
recreate: bool = False,
recursive_triggers: bool = True,
tracer: Optional[Callable] = None,
use_counts_table: bool = False,
execute_plugins: bool = True,
use_old_upsert: bool = False,
strict: bool = False,
):
self.memory_name = None
self.memory = False
self.use_old_upsert = use_old_upsert
assert (filename_or_conn is not None and (not memory and not memory_name)) or (
filename_or_conn is None and (memory or memory_name)
), "Either specify a filename_or_conn or pass memory=True"
if memory_name:
uri = "file:{}?mode=memory&cache=shared".format(memory_name)
self.conn = sqlite3.connect(
uri,
uri=True,
check_same_thread=False,
)
self.memory = True
self.memory_name = memory_name
elif memory or filename_or_conn == ":memory:":
self.conn = sqlite3.connect(":memory:")
self.memory = True
elif isinstance(filename_or_conn, (str, pathlib.Path)):
if recreate and os.path.exists(filename_or_conn):
try:
os.remove(filename_or_conn)
except OSError:
# Avoid mypy and __repr__ errors, see:
# https://github.com/simonw/sqlite-utils/issues/503
self.conn = sqlite3.connect(":memory:")
raise
self.conn = sqlite3.connect(str(filename_or_conn))
else:
assert not recreate, "recreate cannot be used with connections, only paths"
self.conn = filename_or_conn
self._tracer = tracer
if recursive_triggers:
self.execute("PRAGMA recursive_triggers=on;")
self._registered_functions: set = set()
self.use_counts_table = use_counts_table
if execute_plugins:
pm.hook.prepare_connection(conn=self.conn)
self.strict = strict
def close(self):
"Close the SQLite connection, and the underlying database file"
self.conn.close()
@contextlib.contextmanager
def ensure_autocommit_off(self):
"""
Ensure autocommit is off for this database connection.
Example usage::
with db.ensure_autocommit_off():
# do stuff here
This will reset to the previous autocommit state at the end of the block.
"""
old_isolation_level = self.conn.isolation_level
try:
self.conn.isolation_level = None
yield
finally:
self.conn.isolation_level = old_isolation_level
@contextlib.contextmanager
def tracer(self, tracer: Optional[Callable] = None):
"""
Context manager to temporarily set a tracer function - all executed SQL queries will
be passed to this.
The tracer function should accept two arguments: ``sql`` and ``parameters``
Example usage::
with db.tracer(print):
db["creatures"].insert({"name": "Cleo"})
See :ref:`python_api_tracing`.
:param tracer: Callable accepting ``sql`` and ``parameters`` arguments
"""
prev_tracer = self._tracer
self._tracer = tracer or print
try:
yield self
finally:
self._tracer = prev_tracer
def __getitem__(self, table_name: str) -> Union["Table", "View"]:
"""
``db[table_name]`` returns a :class:`.Table` object for the table with the specified name.
If the table does not exist yet it will be created the first time data is inserted into it.
:param table_name: The name of the table
"""
return self.table(table_name)
def __repr__(self) -> str:
return "<Database {}>".format(self.conn)
def register_function(
self,
fn: Optional[Callable] = None,
deterministic: bool = False,
replace: bool = False,
name: Optional[str] = None,
):
"""
``fn`` will be made available as a function within SQL, with the same name and number
of arguments. Can be used as a decorator::
@db.register_function
def upper(value):
return str(value).upper()
The decorator can take arguments::
@db.register_function(deterministic=True, replace=True)
def upper(value):
return str(value).upper()
See :ref:`python_api_register_function`.
:param fn: Function to register
:param deterministic: set ``True`` for functions that always return the same output for a given input
:param replace: set ``True`` to replace an existing function with the same name - otherwise throw an error
:param name: name of the SQLite function - if not specified, the Python function name will be used
"""
def register(fn):
fn_name = name or fn.__name__
arity = len(inspect.signature(fn).parameters)
if not replace and (fn_name, arity) in self._registered_functions:
return fn
kwargs = {}
registered = False
if deterministic:
# Try this, but fall back if sqlite3.NotSupportedError
try:
self.conn.create_function(
fn_name, arity, fn, **dict(kwargs, deterministic=True)
)
registered = True
except sqlite3.NotSupportedError:
pass
if not registered:
self.conn.create_function(fn_name, arity, fn, **kwargs)
self._registered_functions.add((fn_name, arity))
return fn
if fn is None:
return register
else:
register(fn)
def register_fts4_bm25(self):
"Register the ``rank_bm25(match_info)`` function used for calculating relevance with SQLite FTS4."
self.register_function(rank_bm25, deterministic=True, replace=True)
def attach(self, alias: str, filepath: Union[str, pathlib.Path]):
"""
Attach another SQLite database file to this connection with the specified alias, equivalent to::
ATTACH DATABASE 'filepath.db' AS alias
:param alias: Alias name to use
:param filepath: Path to SQLite database file on disk
"""
attach_sql = """
ATTACH DATABASE '{}' AS [{}];
""".format(
str(pathlib.Path(filepath).resolve()), alias
).strip()
self.execute(attach_sql)
def query(
self, sql: str, params: Optional[Union[Iterable, dict]] = None
) -> Generator[dict, None, None]:
"""
Execute ``sql`` and return an iterable of dictionaries representing each row.
:param sql: SQL query to execute
:param params: Parameters to use in that query - an iterable for ``where id = ?``
parameters, or a dictionary for ``where id = :id``
"""
cursor = self.execute(sql, params or tuple())
keys = [d[0] for d in cursor.description]
for row in cursor:
yield dict(zip(keys, row))
def execute(
self, sql: str, parameters: Optional[Union[Iterable, dict]] = None
) -> sqlite3.Cursor:
"""
Execute SQL query and return a ``sqlite3.Cursor``.
:param sql: SQL query to execute
:param parameters: Parameters to use in that query - an iterable for ``where id = ?``
parameters, or a dictionary for ``where id = :id``
"""
if self._tracer:
self._tracer(sql, parameters)
if parameters is not None:
return self.conn.execute(sql, parameters)
else:
return self.conn.execute(sql)
def executescript(self, sql: str) -> sqlite3.Cursor:
"""
Execute multiple SQL statements separated by ; and return the ``sqlite3.Cursor``.
:param sql: SQL to execute
"""
if self._tracer:
self._tracer(sql, None)
return self.conn.executescript(sql)
def table(self, table_name: str, **kwargs) -> Union["Table", "View"]:
"""
Return a table object, optionally configured with default options.
See :ref:`reference_db_table` for option details.
:param table_name: Name of the table
"""
if table_name in self.view_names():
return View(self, table_name, **kwargs)
else:
kwargs.setdefault("strict", self.strict)
return Table(self, table_name, **kwargs)
def quote(self, value: str) -> str:
"""
Apply SQLite string quoting to a value, including wrapping it in single quotes.
:param value: String to quote
"""
# Normally we would use .execute(sql, [params]) for escaping, but
# occasionally that isn't available - most notably when we need
# to include a "... DEFAULT 'value'" in a column definition.
return self.execute(
# Use SQLite itself to correctly escape this string:
"SELECT quote(:value)",
{"value": value},
).fetchone()[0]
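# For example (illustrative): db.quote("it's") returns "'it''s'" - SQLite's quote()
# doubles the embedded single quote and wraps the result in single quotes.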
def quote_fts(self, query: str) -> str:
"""
Escape special characters in a SQLite full-text search query.
This works by surrounding each token within the query with double
quotes, in order to avoid words like ``NOT`` and ``OR`` having
special meaning as defined by the FTS query syntax here:
https://www.sqlite.org/fts5.html#full_text_query_syntax
If the query has unbalanced ``"`` characters, adds one at end.
:param query: String to escape
"""
if query.count('"') % 2:
query += '"'
bits = _quote_fts_re.split(query)
bits = [b for b in bits if b and b != '""']
return " ".join(
'"{}"'.format(bit) if not bit.startswith('"') else bit for bit in bits
)
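# For example (illustrative): db.quote_fts('dog "cat fish') returns '"dog" "cat fish"' -
# the unbalanced double quote is closed and each bare token is wrapped in quotes so
# that FTS keywords such as NOT or OR lose their special meaning.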
def quote_default_value(self, value: str) -> str:
if any(
[
str(value).startswith("'") and str(value).endswith("'"),
str(value).startswith('"') and str(value).endswith('"'),
]
):
return value
if str(value).upper() in ("CURRENT_TIME", "CURRENT_DATE", "CURRENT_TIMESTAMP"):
return value
if str(value).endswith(")"):
# Expr
return "({})".format(value)
return self.quote(value)
def table_names(self, fts4: bool = False, fts5: bool = False) -> List[str]:
"""
List of string table names in this database.
:param fts4: Only return tables that are part of FTS4 indexes
:param fts5: Only return tables that are part of FTS5 indexes
"""
where = ["type = 'table'"]
if fts4:
where.append("sql like '%USING FTS4%'")
if fts5:
where.append("sql like '%USING FTS5%'")
sql = "select name from sqlite_master where {}".format(" AND ".join(where))
return [r[0] for r in self.execute(sql).fetchall()]
def view_names(self) -> List[str]:
"List of string view names in this database."
return [
r[0]
for r in self.execute(
"select name from sqlite_master where type = 'view'"
).fetchall()
]
@property
def tables(self) -> List["Table"]:
"List of Table objects in this database."
return cast(List["Table"], [self[name] for name in self.table_names()])
@property
def views(self) -> List["View"]:
"List of View objects in this database."
return cast(List["View"], [self[name] for name in self.view_names()])
@property
def triggers(self) -> List[Trigger]:
"List of ``(name, table_name, sql)`` tuples representing triggers in this database."
return [
Trigger(*r)
for r in self.execute(
"select name, tbl_name, sql from sqlite_master where type = 'trigger'"
).fetchall()
]
@property
def triggers_dict(self) -> Dict[str, str]:
"A ``{trigger_name: sql}`` dictionary of triggers in this database."
return {trigger.name: trigger.sql for trigger in self.triggers}
@property
def schema(self) -> str:
"SQL schema for this database."
sqls = []
for row in self.execute(
"select sql from sqlite_master where sql is not null"
).fetchall():
sql = row[0]
if not sql.strip().endswith(";"):
sql += ";"
sqls.append(sql)
return "\n".join(sqls)
@property
def supports_strict(self) -> bool:
"Does this database support STRICT mode?"
if not hasattr(self, "_supports_strict"):
try:
table_name = "t{}".format(secrets.token_hex(16))
with self.conn:
self.conn.execute(
"create table {} (name text) strict".format(table_name)
)
self.conn.execute("drop table {}".format(table_name))
self._supports_strict = True
except Exception:
self._supports_strict = False
return self._supports_strict
@property
def supports_on_conflict(self) -> bool:
# SQLite's upsert is implemented as INSERT INTO ... ON CONFLICT DO ...
if not hasattr(self, "_supports_on_conflict"):
try:
table_name = "t{}".format(secrets.token_hex(16))
with self.conn:
self.conn.execute(
"create table {} (id integer primary key, name text)".format(
table_name
)
)
self.conn.execute(
"insert into {} (id, name) values (1, 'one')".format(table_name)
)
self.conn.execute(
(
"insert into {} (id, name) values (1, 'two') "
"on conflict do update set name = 'two'"
).format(table_name)
)
self.conn.execute("drop table {}".format(table_name))
self._supports_on_conflict = True
except Exception:
self._supports_on_conflict = False
return self._supports_on_conflict
@property
def sqlite_version(self) -> Tuple[int, ...]:
"Version of SQLite, as a tuple of integers for example ``(3, 36, 0)``."
row = self.execute("select sqlite_version()").fetchall()[0]
return tuple(map(int, row[0].split(".")))
@property
def journal_mode(self) -> str:
"""
Current ``journal_mode`` of this database.
https://www.sqlite.org/pragma.html#pragma_journal_mode
"""
return self.execute("PRAGMA journal_mode;").fetchone()[0]
def enable_wal(self):
"""
Sets ``journal_mode`` to ``'wal'`` to enable Write-Ahead Log mode.
"""
if self.journal_mode != "wal":
with self.ensure_autocommit_off():
self.execute("PRAGMA journal_mode=wal;")
def disable_wal(self):
"Sets ``journal_mode`` back to ``'delete'`` to disable Write-Ahead Log mode."
if self.journal_mode != "delete":
with self.ensure_autocommit_off():
self.execute("PRAGMA journal_mode=delete;")
def _ensure_counts_table(self):
with self.conn:
self.execute(_COUNTS_TABLE_CREATE_SQL.format(self._counts_table_name))
def enable_counts(self):
"""
Enable trigger-based count caching for every table in the database, see
:ref:`python_api_cached_table_counts`.
"""
self._ensure_counts_table()
for table in self.tables:
if (
table.virtual_table_using is None
and table.name != self._counts_table_name
):
table.enable_counts()
self.use_counts_table = True
def cached_counts(self, tables: Optional[Iterable[str]] = None) -> Dict[str, int]:
"""
Return ``{table_name: count}`` dictionary of cached counts for specified tables, or
all tables if ``tables`` not provided.
:param tables: Subset list of tables to return counts for.
"""
sql = "select [table], count from {}".format(self._counts_table_name)
if tables:
sql += " where [table] in ({})".format(", ".join("?" for table in tables))
try:
return {r[0]: r[1] for r in self.execute(sql, tables).fetchall()}
except OperationalError:
return {}
def reset_counts(self):
"Re-calculate cached counts for tables."
tables = [table for table in self.tables if table.has_counts_triggers]
with self.conn:
self._ensure_counts_table()
counts_table = self[self._counts_table_name]
counts_table.delete_where()
counts_table.insert_all(
{"table": table.name, "count": table.execute_count()}
for table in tables
)
def execute_returning_dicts(
self, sql: str, params: Optional[Union[Iterable, dict]] = None
) -> List[dict]:
return list(self.query(sql, params))
def resolve_foreign_keys(
self, name: str, foreign_keys: ForeignKeysType
) -> List[ForeignKey]:
"""
Given a list of differing foreign_keys definitions, return a list of
fully resolved ForeignKey() named tuples.
:param name: Name of table that foreign keys are being defined for
:param foreign_keys: List of foreign keys, each of which can be a
string, a ForeignKey() named tuple, a tuple of (column, other_table),
or a tuple of (column, other_table, other_column), or a tuple of
(table, column, other_table, other_column)
"""
table = cast(Table, self[name])
if all(isinstance(fk, ForeignKey) for fk in foreign_keys):
return cast(List[ForeignKey], foreign_keys)
if all(isinstance(fk, str) for fk in foreign_keys):
# It's a list of columns
fks = []
for column in foreign_keys:
column = cast(str, column)
other_table = table.guess_foreign_table(column)
other_column = table.guess_foreign_column(other_table)
fks.append(ForeignKey(name, column, other_table, other_column))
return fks
assert all(
isinstance(fk, (tuple, list)) for fk in foreign_keys
), "foreign_keys= should be a list of tuples"
fks = []
for tuple_or_list in foreign_keys:
if len(tuple_or_list) == 4:
assert (
tuple_or_list[0] == name
), "First item in {} should have been {}".format(tuple_or_list, name)
assert len(tuple_or_list) in (
2,
3,
4,
), "foreign_keys= should be a list of tuple pairs or triples"
if len(tuple_or_list) in (3, 4):
if len(tuple_or_list) == 4:
tuple_or_list = cast(Tuple[str, str, str], tuple_or_list[1:])
else:
tuple_or_list = cast(Tuple[str, str, str], tuple_or_list)
fks.append(
ForeignKey(
name, tuple_or_list[0], tuple_or_list[1], tuple_or_list[2]
)
)
else:
# Guess the primary key
fks.append(
ForeignKey(
name,
tuple_or_list[0],
tuple_or_list[1],
table.guess_foreign_column(tuple_or_list[1]),
)
)
return fks
def create_table_sql(
self,
name: str,
columns: Dict[str, Any],
pk: Optional[Any] = None,
foreign_keys: Optional[ForeignKeysType] = None,
column_order: Optional[List[str]] = None,
not_null: Optional[Iterable[str]] = None,
defaults: Optional[Dict[str, Any]] = None,
hash_id: Optional[str] = None,
hash_id_columns: Optional[Iterable[str]] = None,
extracts: Optional[Union[Dict[str, str], List[str]]] = None,
if_not_exists: bool = False,
strict: bool = False,
) -> str:
"""
Returns the SQL ``CREATE TABLE`` statement for creating the specified table.
:param name: Name of table
:param columns: Dictionary mapping column names to their types, for example ``{"name": str, "age": int}``
:param pk: String name of column to use as a primary key, or a tuple of strings for a compound primary key covering multiple columns
:param foreign_keys: List of foreign key definitions for this table
:param column_order: List specifying which columns should come first
:param not_null: List of columns that should be created as ``NOT NULL``
:param defaults: Dictionary specifying default values for columns
:param hash_id: Name of column to be used as a primary key containing a hash of the other columns
:param hash_id_columns: List of columns to be used when calculating the hash ID for a row
:param extracts: List or dictionary of columns to be extracted during inserts, see :ref:`python_api_extracts`
:param if_not_exists: Use ``CREATE TABLE IF NOT EXISTS``
:param strict: Apply STRICT mode to table
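A minimal sketch of how this might be used (the table and column names are illustrative):

.. code-block:: python

    sql = db.create_table_sql("cats", {"id": int, "name": str}, pk="id")
    # CREATE TABLE [cats] (
    #    [id] INTEGER PRIMARY KEY,
    #    [name] TEXT
    # );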
"""
if hash_id_columns and (hash_id is None):
hash_id = "id"
foreign_keys = self.resolve_foreign_keys(name, foreign_keys or [])
foreign_keys_by_column = {fk.column: fk for fk in foreign_keys}
# any extracts will be treated as integer columns with a foreign key
extracts = resolve_extracts(extracts)
for extract_column, extract_table in extracts.items():
if isinstance(extract_column, tuple):
assert False
# Ensure other table exists
if not self[extract_table].exists():
self.create_table(extract_table, {"id": int, "value": str}, pk="id")
columns[extract_column] = int
foreign_keys_by_column[extract_column] = ForeignKey(
name, extract_column, extract_table, "id"
)
# Soundness check not_null, and defaults if provided
not_null = not_null or set()
defaults = defaults or {}
assert columns, "Tables must have at least one column"
assert all(
n in columns for n in not_null
), "not_null set {} includes items not in columns {}".format(
repr(not_null), repr(set(columns.keys()))
)
assert all(
n in columns for n in defaults
), "defaults set {} includes items not in columns {}".format(
repr(set(defaults)), repr(set(columns.keys()))
)
validate_column_names(columns.keys())
column_items = list(columns.items())
if column_order is not None:
def sort_key(p):
return column_order.index(p[0]) if p[0] in column_order else 999
column_items.sort(key=sort_key)
if hash_id:
column_items.insert(0, (hash_id, str))
pk = hash_id
# Soundness check foreign_keys point to existing tables
for fk in foreign_keys:
if fk.other_table == name and columns.get(fk.other_column):
continue
if fk.other_column != "rowid" and not any(
c for c in self[fk.other_table].columns if c.name == fk.other_column
):
raise AlterError(
"No such column: {}.{}".format(fk.other_table, fk.other_column)
)
column_defs = []
# ensure pk is a tuple
single_pk = None
if isinstance(pk, list) and len(pk) == 1 and isinstance(pk[0], str):
pk = pk[0]
if isinstance(pk, str):
single_pk = pk
if pk not in [c[0] for c in column_items]:
column_items.insert(0, (pk, int))
for column_name, column_type in column_items:
column_extras = []
if column_name == single_pk:
column_extras.append("PRIMARY KEY")
if column_name in not_null:
column_extras.append("NOT NULL")
if column_name in defaults and defaults[column_name] is not None:
column_extras.append(
"DEFAULT {}".format(self.quote_default_value(defaults[column_name]))
)
if column_name in foreign_keys_by_column:
column_extras.append(
"REFERENCES [{other_table}]([{other_column}])".format(
other_table=foreign_keys_by_column[column_name].other_table,
other_column=foreign_keys_by_column[column_name].other_column,
)
)
column_type_str = COLUMN_TYPE_MAPPING[column_type]
# Special case for strict tables to map FLOAT to REAL
# Refs https://github.com/simonw/sqlite-utils/issues/644
if strict and column_type_str == "FLOAT":
column_type_str = "REAL"
column_defs.append(
" [{column_name}] {column_type}{column_extras}".format(
column_name=column_name,
column_type=column_type_str,
column_extras=(
(" " + " ".join(column_extras)) if column_extras else ""
),
)
)
extra_pk = ""
if single_pk is None and pk and len(pk) > 1:
extra_pk = ",\n PRIMARY KEY ({pks})".format(
pks=", ".join(["[{}]".format(p) for p in pk])
)
columns_sql = ",\n".join(column_defs)
sql = """CREATE TABLE {if_not_exists}[{table}] (
{columns_sql}{extra_pk}
){strict};
""".format(
if_not_exists="IF NOT EXISTS " if if_not_exists else "",
table=name,
columns_sql=columns_sql,
extra_pk=extra_pk,
strict=" STRICT" if strict and self.supports_strict else "",
)
return sql
def create_table(
self,
name: str,
columns: Dict[str, Any],
pk: Optional[Any] = None,
foreign_keys: Optional[ForeignKeysType] = None,
column_order: Optional[List[str]] = None,
not_null: Optional[Iterable[str]] = None,
defaults: Optional[Dict[str, Any]] = None,
hash_id: Optional[str] = None,
hash_id_columns: Optional[Iterable[str]] = None,
extracts: Optional[Union[Dict[str, str], List[str]]] = None,
if_not_exists: bool = False,
replace: bool = False,
ignore: bool = False,
transform: bool = False,
strict: bool = False,
) -> "Table":
"""
Create a table with the specified name and the specified ``{column_name: type}`` columns.
See :ref:`python_api_explicit_create`.
:param name: Name of table
:param columns: Dictionary mapping column names to their types, for example ``{"name": str, "age": int}``
:param pk: String name of column to use as a primary key, or a tuple of strings for a compound primary key covering multiple columns
:param foreign_keys: List of foreign key definitions for this table
:param column_order: List specifying which columns should come first
:param not_null: List of columns that should be created as ``NOT NULL``
:param defaults: Dictionary specifying default values for columns
:param hash_id: Name of column to be used as a primary key containing a hash of the other columns
:param hash_id_columns: List of columns to be used when calculating the hash ID for a row
:param extracts: List or dictionary of columns to be extracted during inserts, see :ref:`python_api_extracts`
:param if_not_exists: Use ``CREATE TABLE IF NOT EXISTS``
:param replace: Drop and replace table if it already exists
:param ignore: Silently do nothing if table already exists
:param transform: If table already exists transform it to fit the specified schema
:param strict: Apply STRICT mode to table
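For example, a minimal sketch (table and column names are illustrative):

.. code-block:: python

    table = db.create_table(
        "authors",
        {"id": int, "name": str, "score": float},
        pk="id",
        not_null={"name"},
        if_not_exists=True,
    )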
"""
# Transform table to match the new definition if table already exists:
if self[name].exists():
if ignore:
return cast(Table, self[name])
elif replace:
self[name].drop()
if transform and self[name].exists():
table = cast(Table, self[name])
should_transform = False
# First add missing columns and figure out columns to drop
existing_columns = table.columns_dict
missing_columns = dict(
(col_name, col_type)
for col_name, col_type in columns.items()
if col_name not in existing_columns
)
columns_to_drop = [
column for column in existing_columns if column not in columns
]
if missing_columns:
for col_name, col_type in missing_columns.items():
table.add_column(col_name, col_type)
if missing_columns or columns_to_drop or columns != existing_columns:
should_transform = True
# Do we need to change the column order?
if (
column_order
and list(existing_columns)[: len(column_order)] != column_order
):
should_transform = True
# Has the primary key changed?
current_pks = table.pks
desired_pk = None
if isinstance(pk, str):
desired_pk = [pk]
elif pk:
desired_pk = list(pk)
if desired_pk and current_pks != desired_pk:
should_transform = True
# Any not-null changes?
current_not_null = {c.name for c in table.columns if c.notnull}
desired_not_null = set(not_null) if not_null else set()
if current_not_null != desired_not_null:
should_transform = True
# How about defaults?
if defaults and defaults != table.default_values:
should_transform = True
# Only run .transform() if there is something to do
if should_transform:
table.transform(
types=columns,
drop=columns_to_drop,
column_order=column_order,
not_null=not_null,
defaults=defaults,
pk=pk,
)
return table
sql = self.create_table_sql(
name=name,
columns=columns,
pk=pk,
foreign_keys=foreign_keys,
column_order=column_order,
not_null=not_null,
defaults=defaults,
hash_id=hash_id,
hash_id_columns=hash_id_columns,
extracts=extracts,
if_not_exists=if_not_exists,
strict=strict,
)
self.execute(sql)
created_table = self.table(
name,
pk=pk,
foreign_keys=foreign_keys,
column_order=column_order,
not_null=not_null,
defaults=defaults,
hash_id=hash_id,
hash_id_columns=hash_id_columns,
)
return cast(Table, created_table)
def rename_table(self, name: str, new_name: str):
"""
Rename a table.
:param name: Current table name
:param new_name: Name to rename it to
"""
self.execute(
"ALTER TABLE [{name}] RENAME TO [{new_name}]".format(
name=name, new_name=new_name
)
)
def create_view(
self, name: str, sql: str, ignore: bool = False, replace: bool = False
):
"""
Create a new SQL view with the specified name - ``sql`` should start with ``SELECT ...``.
:param name: Name of the view
:param sql: SQL ``SELECT`` query to use for this view.
:param ignore: Set to ``True`` to do nothing if a view with this name already exists
:param replace: Set to ``True`` to replace the view if one with this name already exists
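For example (the view and table names here are illustrative):

.. code-block:: python

    db.create_view(
        "good_dogs",
        "select * from dogs where is_good = 1",
        replace=True,
    )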
"""
assert not (
ignore and replace
), "Use one or the other of ignore/replace, not both"
create_sql = "CREATE VIEW {name} AS {sql}".format(name=name, sql=sql)
if ignore or replace:
# Does view exist already?
if name in self.view_names():
if ignore:
return self
elif replace:
# If SQL is the same, do nothing
if create_sql == self[name].schema:
return self
self[name].drop()
self.execute(create_sql)
return self
def m2m_table_candidates(self, table: str, other_table: str) -> List[str]:
"""
Given two table names, return the names of any tables that could define a
many-to-many relationship between those two tables, based on them having
foreign keys to both of the provided tables.
:param table: Table name
:param other_table: Other table name
"""
candidates = []
tables = {table, other_table}
for table_obj in self.tables:
# Does it have foreign keys to both table and other_table?
has_fks_to = {fk.other_table for fk in table_obj.foreign_keys}
if has_fks_to.issuperset(tables):
candidates.append(table_obj.name)
return candidates
def add_foreign_keys(self, foreign_keys: Iterable[Tuple[str, str, str, str]]):
"""
See :ref:`python_api_add_foreign_keys`.
:param foreign_keys: A list of ``(table, column, other_table, other_column)``
tuples
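For example, a sketch assuming ``books`` and ``authors`` tables already exist:

.. code-block:: python

    db.add_foreign_keys([
        ("books", "author_id", "authors", "id"),
    ])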
"""
# foreign_keys is a list of explicit 4-tuples
assert all(
len(fk) == 4 and isinstance(fk, (list, tuple)) for fk in foreign_keys
), "foreign_keys must be a list of 4-tuples, (table, column, other_table, other_column)"
foreign_keys_to_create = []
# Verify that all tables and columns exist
for table, column, other_table, other_column in foreign_keys:
if not self[table].exists():
raise AlterError("No such table: {}".format(table))
table_obj = self[table]
if not isinstance(table_obj, Table):
raise AlterError("Must be a table, not a view: {}".format(table))
table_obj = cast(Table, table_obj)
if column not in table_obj.columns_dict:
raise AlterError("No such column: {} in {}".format(column, table))
if not self[other_table].exists():
raise AlterError("No such other_table: {}".format(other_table))
if (
other_column != "rowid"
and other_column not in self[other_table].columns_dict
):
raise AlterError(
"No such other_column: {} in {}".format(other_column, other_table)
)
# We will silently skip foreign keys that exist already
if not any(
fk
for fk in table_obj.foreign_keys
if fk.column == column
and fk.other_table == other_table
and fk.other_column == other_column
):
foreign_keys_to_create.append(
(table, column, other_table, other_column)
)
# Group them by table
by_table: Dict[str, List] = {}
for fk in foreign_keys_to_create:
by_table.setdefault(fk[0], []).append(fk)
for table, fks in by_table.items():
cast(Table, self[table]).transform(add_foreign_keys=fks)
self.vacuum()
def index_foreign_keys(self):
"Create indexes for every foreign key column on every table in the database."
for table_name in self.table_names():
table = self[table_name]
existing_indexes = {
i.columns[0] for i in table.indexes if len(i.columns) == 1
}
for fk in table.foreign_keys:
if fk.column not in existing_indexes:
table.create_index([fk.column], find_unique_name=True)
def vacuum(self):
"Run a SQLite ``VACUUM`` against the database."
self.execute("VACUUM;")
def analyze(self, name=None):
"""
Run ``ANALYZE`` against the entire database or a named table or index.
:param name: Run ``ANALYZE`` against this specific named table or index
"""
sql = "ANALYZE"
if name is not None:
sql += " [{}]".format(name)
self.execute(sql)
def iterdump(self) -> Generator[str, None, None]:
"A sequence of strings representing a SQL dump of the database"
if iterdump:
yield from iterdump(self.conn)
else:
try:
yield from self.conn.iterdump()
except AttributeError:
raise AttributeError(
"conn.iterdump() not found - try pip install sqlite-dump"
)
def init_spatialite(self, path: Optional[str] = None) -> bool:
"""
The ``init_spatialite`` method will load and initialize the SpatiaLite extension.
The ``path`` argument should be an absolute path to the compiled extension, which
can be found using ``find_spatialite``.
Returns ``True`` if SpatiaLite was successfully initialized.
.. code-block:: python
from sqlite_utils.db import Database
from sqlite_utils.utils import find_spatialite
db = Database("mydb.db")
db.init_spatialite(find_spatialite())
If you've installed SpatiaLite somewhere unexpected (for testing an alternate version, for example),
you can pass in an absolute path:
.. code-block:: python
from sqlite_utils.db import Database
from sqlite_utils.utils import find_spatialite
db = Database("mydb.db")
db.init_spatialite("./local/mod_spatialite.dylib")
:param path: Path to SpatiaLite module on disk
"""
if path is None:
path = find_spatialite()
self.conn.enable_load_extension(True)
self.conn.load_extension(path)
# Initialize SpatiaLite if not yet initialized
if "spatial_ref_sys" in self.table_names():
return False
cursor = self.execute("select InitSpatialMetadata(1)")
result = cursor.fetchone()
return result and bool(result[0])
class Queryable:
def exists(self) -> bool:
"Does this table or view exist yet?"
return False
def __init__(self, db, name):
self.db = db
self.name = name
def count_where(
self,
where: Optional[str] = None,
where_args: Optional[Union[Iterable, dict]] = None,
) -> int:
"""
Executes ``SELECT count(*) FROM table WHERE ...`` and returns a count.
:param where: SQL where fragment to use, for example ``id > ?``
:param where_args: Parameters to use with that fragment - an iterable for ``id > ?``
parameters, or a dictionary for ``id > :id``
"""
sql = "select count(*) from [{}]".format(self.name)
if where is not None:
sql += " where " + where
return self.db.execute(sql, where_args or []).fetchone()[0]
def execute_count(self):
# Backwards compatibility, see https://github.com/simonw/sqlite-utils/issues/305#issuecomment-890713185
return self.count_where()
@property
def count(self) -> int:
"A count of the rows in this table or view."
return self.count_where()
@property
def rows(self) -> Generator[dict, None, None]:
"Iterate over every row in this table or view, returning each one as a dictionary."
return self.rows_where()
def rows_where(
self,
where: Optional[str] = None,
where_args: Optional[Union[Iterable, dict]] = None,
order_by: Optional[str] = None,
select: str = "*",
limit: Optional[int] = None,
offset: Optional[int] = None,
) -> Generator[dict, None, None]:
"""
Iterate over every row in this table or view that matches the specified where clause.
Returns each row as a dictionary. See :ref:`python_api_rows` for more details.
:param where: SQL where fragment to use, for example ``id > ?``
:param where_args: Parameters to use with that fragment - an iterable for ``id > ?``
parameters, or a dictionary for ``id > :id``
:param order_by: Column or fragment of SQL to order by
:param select: Comma-separated list of columns to select - defaults to ``*``
:param limit: Integer number of rows to limit to
:param offset: Integer for SQL offset
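For example, a sketch assuming a ``dogs`` table with an ``age`` column:

.. code-block:: python

    for row in db["dogs"].rows_where("age > ?", [3], order_by="age desc", limit=10):
        print(row)  # each row is a plain dictionary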
"""
if not self.exists():
return
sql = "select {} from [{}]".format(select, self.name)
if where is not None:
sql += " where " + where
if order_by is not None:
sql += " order by " + order_by
if limit is not None:
sql += " limit {}".format(limit)
if offset is not None:
sql += " offset {}".format(offset)
cursor = self.db.execute(sql, where_args or [])
columns = [c[0] for c in cursor.description]
for row in cursor:
yield dict(zip(columns, row))
def pks_and_rows_where(
self,
where: Optional[str] = None,
where_args: Optional[Union[Iterable, dict]] = None,
order_by: Optional[str] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
) -> Generator[Tuple[Any, Dict], None, None]:
"""
Like ``.rows_where()`` but returns ``(pk, row)`` pairs - ``pk`` can be a single value or tuple.
:param where: SQL where fragment to use, for example ``id > ?``
:param where_args: Parameters to use with that fragment - an iterable for ``id > ?``
parameters, or a dictionary for ``id > :id``
:param order_by: Column or fragment of SQL to order by
:param limit: Integer number of rows to limit to
:param offset: Integer for SQL offset
"""
column_names = [column.name for column in self.columns]
pks = [column.name for column in self.columns if column.is_pk]
if not pks:
column_names.insert(0, "rowid")
pks = ["rowid"]
select = ",".join("[{}]".format(column_name) for column_name in column_names)
for row in self.rows_where(
select=select,
where=where,
where_args=where_args,
order_by=order_by,
limit=limit,
offset=offset,
):
row_pk = tuple(row[pk] for pk in pks)
if len(row_pk) == 1:
row_pk = row_pk[0]
yield row_pk, row
@property
def columns(self) -> List["Column"]:
"List of :ref:`Columns <reference_db_other_column>` representing the columns in this table or view."
if not self.exists():
return []
rows = self.db.execute("PRAGMA table_info([{}])".format(self.name)).fetchall()
return [Column(*row) for row in rows]
@property
def columns_dict(self) -> Dict[str, Any]:
"``{column_name: python-type}`` dictionary representing columns in this table or view."
return {column.name: column_affinity(column.type) for column in self.columns}
@property
def schema(self) -> str:
"SQL schema for this table or view."
return self.db.execute(
"select sql from sqlite_master where name = ?", (self.name,)
).fetchone()[0]
class Table(Queryable):
"""
Tables should usually be initialized using the ``db.table(table_name)`` or
``db[table_name]`` methods.
The following optional parameters can be passed to ``db.table(table_name, ...)``:
:param db: Provided by ``db.table(table_name)``
:param name: Provided by ``db.table(table_name)``
:param pk: Name of the primary key column, or tuple of columns
:param foreign_keys: List of foreign key definitions
:param column_order: List of column names in the order they should be in the table
:param not_null: List of columns that cannot be null
:param defaults: Dictionary of column names and default values
:param batch_size: Integer number of rows to insert at a time
:param hash_id: Name of a column to use as the primary key, containing a hash of the other column values
:param hash_id_columns: List of columns to use for the hash_id
:param alter: If True, automatically alter the table if it doesn't match the schema
:param ignore: If True, ignore rows that already exist when inserting
:param replace: If True, replace rows that already exist when inserting
:param extracts: Dictionary or list of column names to extract into a separate table on inserts
:param conversions: Dictionary of column names and conversion functions
:param columns: Dictionary of column names to column types
:param strict: If True, apply STRICT mode to table
"""
#: The ``rowid`` of the last inserted, updated or selected row.
last_rowid: Optional[int] = None
#: The primary key of the last inserted, updated or selected row.
last_pk: Optional[Any] = None
def __init__(
self,
db: Database,
name: str,
pk: Optional[Any] = None,
foreign_keys: Optional[ForeignKeysType] = None,
column_order: Optional[List[str]] = None,
not_null: Optional[Iterable[str]] = None,
defaults: Optional[Dict[str, Any]] = None,
batch_size: int = 100,
hash_id: Optional[str] = None,
hash_id_columns: Optional[Iterable[str]] = None,
alter: bool = False,
ignore: bool = False,
replace: bool = False,
extracts: Optional[Union[Dict[str, str], List[str]]] = None,
conversions: Optional[dict] = None,
columns: Optional[Dict[str, Any]] = None,
strict: bool = False,
):
super().__init__(db, name)
self._defaults = dict(
pk=pk,
foreign_keys=foreign_keys,
column_order=column_order,
not_null=not_null,
defaults=defaults,
batch_size=batch_size,
hash_id=hash_id,
hash_id_columns=hash_id_columns,
alter=alter,
ignore=ignore,
replace=replace,
extracts=extracts,
conversions=conversions or {},
columns=columns,
strict=strict,
)
def __repr__(self) -> str:
return "<Table {}{}>".format(
self.name,
(
" (does not exist yet)"
if not self.exists()
else " ({})".format(", ".join(c.name for c in self.columns))
),
)
@property
def count(self) -> int:
"Count of the rows in this table - optionally from the table count cache, if configured."
if self.db.use_counts_table:
counts = self.db.cached_counts([self.name])
if counts:
return next(iter(counts.values()))
return self.count_where()
def exists(self) -> bool:
return self.name in self.db.table_names()
@property
def pks(self) -> List[str]:
"Primary key columns for this table."
names = [column.name for column in self.columns if column.is_pk]
if not names:
names = ["rowid"]
return names
@property
def use_rowid(self) -> bool:
"Does this table use ``rowid`` for its primary key (no other primary keys are specified)?"
return not any(column for column in self.columns if column.is_pk)
def get(self, pk_values: Union[list, tuple, str, int]) -> dict:
"""
Return row (as dictionary) for the specified primary key.
Raises ``sqlite_utils.db.NotFoundError`` if a matching row cannot be found.
:param pk_values: A single value, or a tuple of values for tables that have a compound primary key
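For example, a sketch assuming a ``dogs`` table with a single-column primary key:

.. code-block:: python

    from sqlite_utils.db import NotFoundError

    try:
        row = db["dogs"].get(1)
    except NotFoundError:
        row = None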
"""
if not isinstance(pk_values, (list, tuple)):
pk_values = [pk_values]
pks = self.pks
last_pk = pk_values[0] if len(pks) == 1 else pk_values
if len(pks) != len(pk_values):
raise NotFoundError(
"Need {} primary key value{}".format(
len(pks), "" if len(pks) == 1 else "s"
)
)
wheres = ["[{}] = ?".format(pk_name) for pk_name in pks]
rows = self.rows_where(" and ".join(wheres), pk_values)
try:
row = list(rows)[0]
self.last_pk = last_pk
return row
except IndexError:
raise NotFoundError
@property
def foreign_keys(self) -> List["ForeignKey"]:
"List of foreign keys defined on this table."
fks = []
for row in self.db.execute(
"PRAGMA foreign_key_list([{}])".format(self.name)
).fetchall():
if row is not None:
id, seq, table_name, from_, to_, on_update, on_delete, match = row
fks.append(
ForeignKey(
table=self.name,
column=from_,
other_table=table_name,
other_column=to_,
)
)
return fks
@property
def virtual_table_using(self) -> Optional[str]:
"Type of virtual table, or ``None`` if this is not a virtual table."
match = _virtual_table_using_re.match(self.schema)
if match is None:
return None
return match.groupdict()["using"].upper()
@property
def indexes(self) -> List[Index]:
"List of indexes defined on this table."
sql = 'PRAGMA index_list("{}")'.format(self.name)
indexes = []
for row in self.db.execute_returning_dicts(sql):
index_name = row["name"]
index_name_quoted = (
'"{}"'.format(index_name)
if not index_name.startswith('"')
else index_name
)
column_sql = "PRAGMA index_info({})".format(index_name_quoted)
columns = []
for seqno, cid, name in self.db.execute(column_sql).fetchall():
columns.append(name)
row["columns"] = columns
# These columns may be missing on older SQLite versions:
for key, default in {"origin": "c", "partial": 0}.items():
if key not in row:
row[key] = default
indexes.append(Index(**row))
return indexes
@property
def xindexes(self) -> List[XIndex]:
"List of indexes defined on this table using the more detailed ``XIndex`` format."
sql = 'PRAGMA index_list("{}")'.format(self.name)
indexes = []
for row in self.db.execute_returning_dicts(sql):
index_name = row["name"]
index_name_quoted = (
'"{}"'.format(index_name)
if not index_name.startswith('"')
else index_name
)
column_sql = "PRAGMA index_xinfo({})".format(index_name_quoted)
index_columns = []
for info in self.db.execute(column_sql).fetchall():
index_columns.append(XIndexColumn(*info))
indexes.append(XIndex(index_name, index_columns))
return indexes
@property
def triggers(self) -> List[Trigger]:
"List of triggers defined on this table."
return [
Trigger(*r)
for r in self.db.execute(
"select name, tbl_name, sql from sqlite_master where type = 'trigger'"
" and tbl_name = ?",
(self.name,),
).fetchall()
]
@property
def triggers_dict(self) -> Dict[str, str]:
"``{trigger_name: sql}`` dictionary of triggers defined on this table."
return {trigger.name: trigger.sql for trigger in self.triggers}
@property
def default_values(self) -> Dict[str, Any]:
"``{column_name: default_value}`` dictionary of default values for columns in this table."
return {
column.name: _decode_default_value(column.default_value)
for column in self.columns
if column.default_value is not None
}
@property
def strict(self) -> bool:
"Is this a STRICT table?"
table_suffix = self.schema.split(")")[-1].strip().upper()
table_options = [bit.strip() for bit in table_suffix.split(",")]
return "STRICT" in table_options
def create(
self,
columns: Dict[str, Any],
pk: Optional[Any] = None,
foreign_keys: Optional[ForeignKeysType] = None,
column_order: Optional[List[str]] = None,
not_null: Optional[Iterable[str]] = None,
defaults: Optional[Dict[str, Any]] = None,
hash_id: Optional[str] = None,
hash_id_columns: Optional[Iterable[str]] = None,
extracts: Optional[Union[Dict[str, str], List[str]]] = None,
if_not_exists: bool = False,
replace: bool = False,
ignore: bool = False,
transform: bool = False,
strict: bool = False,
) -> "Table":
"""
Create a table with the specified columns.
See :ref:`python_api_explicit_create` for full details.
:param columns: Dictionary mapping column names to their types, for example ``{"name": str, "age": int}``
:param pk: String name of column to use as a primary key, or a tuple of strings for a compound primary key covering multiple columns
:param foreign_keys: List of foreign key definitions for this table
:param column_order: List specifying which columns should come first
:param not_null: List of columns that should be created as ``NOT NULL``
:param defaults: Dictionary specifying default values for columns
:param hash_id: Name of column to be used as a primary key containing a hash of the other columns
:param hash_id_columns: List of columns to be used when calculating the hash ID for a row
:param extracts: List or dictionary of columns to be extracted during inserts, see :ref:`python_api_extracts`
:param if_not_exists: Use ``CREATE TABLE IF NOT EXISTS``
:param replace: Drop and replace table if it already exists
:param ignore: Silently do nothing if table already exists
:param transform: If table already exists transform it to fit the specified schema
:param strict: Apply STRICT mode to table
"""
columns = {name: value for (name, value) in columns.items()}
with self.db.conn:
self.db.create_table(
self.name,
columns,
pk=pk,
foreign_keys=foreign_keys,
column_order=column_order,
not_null=not_null,
defaults=defaults,
hash_id=hash_id,
hash_id_columns=hash_id_columns,
extracts=extracts,
if_not_exists=if_not_exists,
replace=replace,
ignore=ignore,
transform=transform,
strict=strict,
)
return self
def duplicate(self, new_name: str) -> "Table":
"""
Create a duplicate of this table, copying across the schema and all row data.
:param new_name: Name of the new table
"""
if not self.exists():
raise NoTable(f"Table {self.name} does not exist")
with self.db.conn:
sql = "CREATE TABLE [{new_table}] AS SELECT * FROM [{table}];".format(
new_table=new_name,
table=self.name,
)
self.db.execute(sql)
return self.db[new_name]
def transform(
self,
*,
types: Optional[dict] = None,
rename: Optional[dict] = None,
drop: Optional[Iterable] = None,
pk: Optional[Any] = DEFAULT,
not_null: Optional[Iterable[str]] = None,
defaults: Optional[Dict[str, Any]] = None,
drop_foreign_keys: Optional[Iterable[str]] = None,
add_foreign_keys: Optional[ForeignKeysType] = None,
foreign_keys: Optional[ForeignKeysType] = None,
column_order: Optional[List[str]] = None,
keep_table: Optional[str] = None,
) -> "Table":
"""
Apply an advanced alter table, including operations that are not supported by
``ALTER TABLE`` in SQLite itself.
See :ref:`python_api_transform` for full details.
:param types: Columns that should have their type changed, for example ``{"weight": float}``
:param rename: Columns to rename, for example ``{"headline": "title"}``
:param drop: Columns to drop
:param pk: New primary key for the table
:param not_null: Columns to set as ``NOT NULL``
:param defaults: Default values for columns
:param drop_foreign_keys: Names of columns that should have their foreign key constraints removed
:param add_foreign_keys: List of foreign keys to add to the table
:param foreign_keys: List of foreign keys to set for the table, replacing any existing foreign keys
:param column_order: List of strings specifying a full or partial column order
to use when creating the table
:param keep_table: If specified, the existing table will be renamed to this and will not be
dropped
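For example, a sketch applied to a hypothetical ``places`` table:

.. code-block:: python

    db["places"].transform(
        rename={"post_code": "zip_code"},
        types={"latitude": float, "longitude": float},
        drop={"legacy_id"},
    )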
"""
assert self.exists(), "Cannot transform a table that doesn't exist yet"
sqls = self.transform_sql(
types=types,
rename=rename,
drop=drop,
pk=pk,
not_null=not_null,
defaults=defaults,
drop_foreign_keys=drop_foreign_keys,
add_foreign_keys=add_foreign_keys,
foreign_keys=foreign_keys,
column_order=column_order,
keep_table=keep_table,
)
pragma_foreign_keys_was_on = self.db.execute("PRAGMA foreign_keys").fetchone()[
0
]
try:
if pragma_foreign_keys_was_on:
self.db.execute("PRAGMA foreign_keys=0;")
with self.db.conn:
for sql in sqls:
self.db.execute(sql)
# Run the foreign_key_check before we commit
if pragma_foreign_keys_was_on:
self.db.execute("PRAGMA foreign_key_check;")
finally:
if pragma_foreign_keys_was_on:
self.db.execute("PRAGMA foreign_keys=1;")
return self
def transform_sql(
self,
*,
types: Optional[dict] = None,
rename: Optional[dict] = None,
drop: Optional[Iterable] = None,
pk: Optional[Any] = DEFAULT,
not_null: Optional[Iterable[str]] = None,
defaults: Optional[Dict[str, Any]] = None,
drop_foreign_keys: Optional[Iterable] = None,
add_foreign_keys: Optional[ForeignKeysType] = None,
foreign_keys: Optional[ForeignKeysType] = None,
column_order: Optional[List[str]] = None,
tmp_suffix: Optional[str] = None,
keep_table: Optional[str] = None,
) -> List[str]:
"""
Return a list of SQL statements that should be executed in order to apply this transformation.
:param types: Columns that should have their type changed, for example ``{"weight": float}``
:param rename: Columns to rename, for example ``{"headline": "title"}``
:param drop: Columns to drop
:param pk: New primary key for the table
:param not_null: Columns to set as ``NOT NULL``
:param defaults: Default values for columns
:param drop_foreign_keys: Names of columns that should have their foreign key constraints removed
:param add_foreign_keys: List of foreign keys to add to the table
:param foreign_keys: List of foreign keys to set for the table, replacing any existing foreign keys
:param column_order: List of strings specifying a full or partial column order
to use when creating the table
:param tmp_suffix: Suffix to use for the temporary table name
:param keep_table: If specified, the existing table will be renamed to this and will not be
dropped
"""
types = types or {}
rename = rename or {}
drop = drop or set()
create_table_foreign_keys: List[ForeignKeyIndicator] = []
if foreign_keys is not None:
if add_foreign_keys is not None:
raise ValueError(
"Cannot specify both foreign_keys and add_foreign_keys"
)
if drop_foreign_keys is not None:
raise ValueError(
"Cannot specify both foreign_keys and drop_foreign_keys"
)
create_table_foreign_keys.extend(foreign_keys)
else:
# Construct foreign_keys from current, plus add_foreign_keys, minus drop_foreign_keys
create_table_foreign_keys = []
for table, column, other_table, other_column in self.foreign_keys:
# Copy over old foreign keys, unless we are dropping them
if (drop_foreign_keys is None) or (column not in drop_foreign_keys):
create_table_foreign_keys.append(
ForeignKey(
table,
rename.get(column) or column,
other_table,
other_column,
)
)
# Add new foreign keys
if add_foreign_keys is not None:
for fk in self.db.resolve_foreign_keys(self.name, add_foreign_keys):
create_table_foreign_keys.append(
ForeignKey(
self.name,
rename.get(fk.column) or fk.column,
fk.other_table,
fk.other_column,
)
)
new_table_name = "{}_new_{}".format(
self.name, tmp_suffix or os.urandom(6).hex()
)
current_column_pairs = list(self.columns_dict.items())
new_column_pairs = []
copy_from_to = {column: column for column, _ in current_column_pairs}
for name, type_ in current_column_pairs:
type_ = types.get(name) or type_
if name in drop:
del [copy_from_to[name]]
continue
new_name = rename.get(name) or name
new_column_pairs.append((new_name, type_))
copy_from_to[name] = new_name
if pk is DEFAULT:
pks_renamed = tuple(
rename.get(p.name) or p.name for p in self.columns if p.is_pk
)
if len(pks_renamed) == 1:
pk = pks_renamed[0]
else:
pk = pks_renamed
# not_null may be a set or dict, need to convert to a set
create_table_not_null = {
rename.get(c.name) or c.name
for c in self.columns
if c.notnull
if c.name not in drop
}
if isinstance(not_null, dict):
# Remove any columns with a value of False
for key, value in not_null.items():
# Column may have been renamed
key = rename.get(key) or key
if value is False and key in create_table_not_null:
create_table_not_null.remove(key)
else:
create_table_not_null.add(key)
elif isinstance(not_null, set):
create_table_not_null.update((rename.get(k) or k) for k in not_null)
elif not not_null:
pass
else:
assert False, "not_null must be a dict or a set or None, it was {}".format(
repr(not_null)
)
# defaults=
create_table_defaults = {
(rename.get(c.name) or c.name): c.default_value
for c in self.columns
if c.default_value is not None and c.name not in drop
}
if defaults is not None:
create_table_defaults.update(
{rename.get(c) or c: v for c, v in defaults.items()}
)
if column_order is not None:
column_order = [rename.get(col) or col for col in column_order]
sqls = []
sqls.append(
self.db.create_table_sql(
new_table_name,
dict(new_column_pairs),
pk=pk,
not_null=create_table_not_null,
defaults=create_table_defaults,
foreign_keys=create_table_foreign_keys,
column_order=column_order,
strict=self.strict,
).strip()
)
# Copy across data, respecting any renamed columns
new_cols = []
old_cols = []
for from_, to_ in copy_from_to.items():
old_cols.append(from_)
new_cols.append(to_)
# Ensure rowid is copied too
if "rowid" not in new_cols:
new_cols.insert(0, "rowid")
old_cols.insert(0, "rowid")
copy_sql = "INSERT INTO [{new_table}] ({new_cols})\n SELECT {old_cols} FROM [{old_table}];".format(
new_table=new_table_name,
old_table=self.name,
old_cols=", ".join("[{}]".format(col) for col in old_cols),
new_cols=", ".join("[{}]".format(col) for col in new_cols),
)
sqls.append(copy_sql)
# Drop (or keep) the old table
if keep_table:
sqls.append(
"ALTER TABLE [{}] RENAME TO [{}];".format(self.name, keep_table)
)
else:
sqls.append("DROP TABLE [{}];".format(self.name))
# Rename the new one
sqls.append(
"ALTER TABLE [{}] RENAME TO [{}];".format(new_table_name, self.name)
)
# Re-add existing indexes
for index in self.indexes:
if index.origin != "pk":
index_sql = self.db.execute(
"""SELECT sql FROM sqlite_master WHERE type = 'index' AND name = :index_name;""",
{"index_name": index.name},
).fetchall()[0][0]
if index_sql is None:
raise TransformError(
f"Index '{index.name}' on table '{self.name}' does not have a "
"CREATE INDEX statement. You must manually drop this index prior to running this "
"transformation and manually recreate the new index after running this transformation."
)
if keep_table:
sqls.append(f"DROP INDEX IF EXISTS [{index.name}];")
for col in index.columns:
if col in rename.keys() or col in drop:
raise TransformError(
f"Index '{index.name}' column '{col}' is not in updated table '{self.name}'. "
f"You must manually drop this index prior to running this transformation "
f"and manually recreate the new index after running this transformation. "
f"The original index sql statement is: `{index_sql}`. No changes have been applied to this table."
)
sqls.append(index_sql)
return sqls
def extract(
self,
columns: Union[str, Iterable[str]],
table: Optional[str] = None,
fk_column: Optional[str] = None,
rename: Optional[Dict[str, str]] = None,
) -> "Table":
"""
Extract specified columns into a separate table.
See :ref:`python_api_extract` for details.
:param columns: Single column or list of columns that should be extracted
:param table: Name of table in which the new records should be created
:param fk_column: Name of the foreign key column to populate in the original table
:param rename: Dictionary of columns that should be renamed when populating the new table
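For example, a sketch that pulls a ``species`` column out of a hypothetical ``trees`` table:

.. code-block:: python

    db["trees"].extract("species", table="species", fk_column="species_id")
    # trees now has a species_id foreign key column pointing at the new species table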
"""
rename = rename or {}
if isinstance(columns, str):
columns = [columns]
if not set(columns).issubset(self.columns_dict.keys()):
raise InvalidColumns(
"Invalid columns {} for table with columns {}".format(
columns, list(self.columns_dict.keys())
)
)
table = table or "_".join(columns)
lookup_table = self.db[table]
fk_column = fk_column or "{}_id".format(table)
magic_lookup_column = "{}_{}".format(fk_column, os.urandom(6).hex())
# Populate the lookup table with all of the extracted unique values
lookup_columns_definition = {
(rename.get(col) or col): typ
for col, typ in self.columns_dict.items()
if col in columns
}
if lookup_table.exists():
if not set(lookup_columns_definition.items()).issubset(
lookup_table.columns_dict.items()
):
raise InvalidColumns(
"Lookup table {} already exists but does not have columns {}".format(
table, lookup_columns_definition
)
)
else:
lookup_table.create(
{
**{
"id": int,
},
**lookup_columns_definition,
},
pk="id",
)
lookup_columns = [(rename.get(col) or col) for col in columns]
lookup_table.create_index(lookup_columns, unique=True, if_not_exists=True)
self.db.execute(
"INSERT OR IGNORE INTO [{lookup_table}] ({lookup_columns}) SELECT DISTINCT {table_cols} FROM [{table}]".format(
lookup_table=table,
lookup_columns=", ".join("[{}]".format(c) for c in lookup_columns),
table_cols=", ".join("[{}]".format(c) for c in columns),
table=self.name,
)
)
# Now add the new fk_column
self.add_column(magic_lookup_column, int)
# And populate it
self.db.execute(
"UPDATE [{table}] SET [{magic_lookup_column}] = (SELECT id FROM [{lookup_table}] WHERE {where})".format(
table=self.name,
magic_lookup_column=magic_lookup_column,
lookup_table=table,
where=" AND ".join(
"[{table}].[{column}] IS [{lookup_table}].[{lookup_column}]".format(
table=self.name,
lookup_table=table,
column=column,
lookup_column=rename.get(column) or column,
)
for column in columns
),
)
)
# Figure out the right column order
column_order = []
for c in self.columns:
if c.name in columns and magic_lookup_column not in column_order:
column_order.append(magic_lookup_column)
elif c.name == magic_lookup_column:
continue
else:
column_order.append(c.name)
# Drop the unnecessary columns and rename lookup column
self.transform(
drop=set(columns),
rename={magic_lookup_column: fk_column},
column_order=column_order,
)
# And add the foreign key constraint
self.add_foreign_key(fk_column, table, "id")
return self
def create_index(
self,
columns: Iterable[Union[str, DescIndex]],
index_name: Optional[str] = None,
unique: bool = False,
if_not_exists: bool = False,
find_unique_name: bool = False,
analyze: bool = False,
):
"""
Create an index on this table.
:param columns: A single column or list of columns to index. These can be strings or,
to create an index using the column in descending order, ``db.DescIndex(column_name)`` objects.
:param index_name: The name to use for the new index. Defaults to the column names joined on ``_``.
:param unique: Should the index be marked as unique, forcing unique values?
:param if_not_exists: Only create the index if one with that name does not already exist.
:param find_unique_name: If ``index_name`` is not provided and the automatically derived name
already exists, keep incrementing a suffix number to find an available name.
:param analyze: Run ``ANALYZE`` against this index after creating it.
See :ref:`python_api_create_index`.
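For example (the table and column names here are illustrative):

.. code-block:: python

    from sqlite_utils.db import DescIndex

    db["dogs"].create_index(["name"])
    db["dogs"].create_index([DescIndex("age"), "name"], unique=True, if_not_exists=True)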
"""
if index_name is None:
index_name = "idx_{}_{}".format(
self.name.replace(" ", "_"), "_".join(columns)
)
columns_sql = []
for column in columns:
if isinstance(column, DescIndex):
fmt = "[{}] desc"
else:
fmt = "[{}]"
columns_sql.append(fmt.format(column))
suffix = None
created_index_name = None
while True:
created_index_name = (
"{}_{}".format(index_name, suffix) if suffix else index_name
)
sql = (
textwrap.dedent(
"""
CREATE {unique}INDEX {if_not_exists}[{index_name}]
ON [{table_name}] ({columns});
"""
)
.strip()
.format(
index_name=created_index_name,
table_name=self.name,
columns=", ".join(columns_sql),
unique="UNIQUE " if unique else "",
if_not_exists="IF NOT EXISTS " if if_not_exists else "",
)
)
try:
self.db.execute(sql)
break
except OperationalError as e:
# find_unique_name=True - try again if 'index ... already exists'
arg = e.args[0]
if (
find_unique_name
and arg.startswith("index ")
and arg.endswith(" already exists")
):
if suffix is None:
suffix = 2
else:
suffix += 1
continue
else:
raise e
if analyze:
self.db.analyze(created_index_name)
return self
def add_column(
self,
col_name: str,
col_type: Optional[Any] = None,
fk: Optional[str] = None,
fk_col: Optional[str] = None,
not_null_default: Optional[Any] = None,
):
"""
Add a column to this table. See :ref:`python_api_add_column`.
:param col_name: Name of the new column
:param col_type: Column type - a Python type such as ``str`` or a SQLite type string such as ``"BLOB"``
:param fk: Name of a table that this column should be a foreign key reference to
:param fk_col: Column in the foreign key table that this should reference
:param not_null_default: Set this column to ``not null`` and give it this default value
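For example, a sketch adding columns to a hypothetical ``dogs`` table (the second call
assumes a ``species`` table already exists):

.. code-block:: python

    db["dogs"].add_column("weight", float)
    db["dogs"].add_column("species_id", int, fk="species")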
"""
fk_col_type = None
if fk is not None:
# fk must be a valid table
if fk not in self.db.table_names():
raise AlterError("table '{}' does not exist".format(fk))
# if fk_col specified, must be a valid column
if fk_col is not None:
if fk_col not in self.db[fk].columns_dict:
raise AlterError("table '{}' has no column {}".format(fk, fk_col))
else:
# automatically set fk_col to first primary_key of fk table
pks = [c for c in self.db[fk].columns if c.is_pk]
if pks:
fk_col = pks[0].name
fk_col_type = pks[0].type
else:
fk_col = "rowid"
fk_col_type = "INTEGER"
if col_type is None:
col_type = str
not_null_sql = None
if not_null_default is not None:
not_null_sql = "NOT NULL DEFAULT {}".format(
self.db.quote_default_value(not_null_default)
)
sql = "ALTER TABLE [{table}] ADD COLUMN [{col_name}] {col_type}{not_null_default};".format(
table=self.name,
col_name=col_name,
col_type=fk_col_type or COLUMN_TYPE_MAPPING[col_type],
not_null_default=(" " + not_null_sql) if not_null_sql else "",
)
self.db.execute(sql)
if fk is not None:
self.add_foreign_key(col_name, fk, fk_col)
return self
def drop(self, ignore: bool = False):
"""
Drop this table.
:param ignore: Set to ``True`` to ignore the error if the table does not exist
"""
try:
self.db.execute("DROP TABLE [{}]".format(self.name))
except sqlite3.OperationalError:
if not ignore:
raise
def guess_foreign_table(self, column: str) -> str:
"""
For a given column, suggest another table that might be referenced by this
column should it be used as a foreign key.
For example, a column called ``tag_id`` or ``tag`` or ``tags`` might suggest
a ``tag`` table, if one exists.
If no candidates can be found, raises a ``NoObviousTable`` exception.
:param column: Name of column
"""
column = column.lower()
possibilities = [column]
if column.endswith("_id"):
column_without_id = column[:-3]
possibilities.append(column_without_id)
if not column_without_id.endswith("s"):
possibilities.append(column_without_id + "s")
elif not column.endswith("s"):
possibilities.append(column + "s")
existing_tables = {t.lower(): t for t in self.db.table_names()}
for table in possibilities:
if table in existing_tables:
return existing_tables[table]
# If we get here there's no obvious candidate - raise an error
raise NoObviousTable(
"No obvious foreign key table for column '{}' - tried {}".format(
column, repr(possibilities)
)
)
def guess_foreign_column(self, other_table: str):
pks = [c for c in self.db[other_table].columns if c.is_pk]
if len(pks) != 1:
raise BadPrimaryKey(
"Could not detect single primary key for table '{}'".format(other_table)
)
else:
return pks[0].name
def add_foreign_key(
self,
column: str,
other_table: Optional[str] = None,
other_column: Optional[str] = None,
ignore: bool = False,
):
"""
Alter the schema to mark the specified column as a foreign key to another table.
:param column: The column to mark as a foreign key.
:param other_table: The table it refers to - if omitted, will be guessed based on the column name.
:param other_column: The column on the other table that it refers to - if omitted, will be guessed.
:param ignore: Set this to ``True`` to ignore an existing foreign key - otherwise an ``AlterError`` will be raised.
"""
# Ensure column exists
if column not in self.columns_dict:
raise AlterError("No such column: {}".format(column))
# If other_table is not specified, attempt to guess it from the column
if other_table is None:
other_table = self.guess_foreign_table(column)
# If other_column is not specified, detect the primary key on other_table
if other_column is None:
other_column = self.guess_foreign_column(other_table)
# Soundness check that the other column exists
if (
not [c for c in self.db[other_table].columns if c.name == other_column]
and other_column != "rowid"
):
raise AlterError("No such column: {}.{}".format(other_table, other_column))
# Check we do not already have an existing foreign key
if any(
fk
for fk in self.foreign_keys
if fk.column == column
and fk.other_table == other_table
and fk.other_column == other_column
):
if ignore:
return self
else:
raise AlterError(
"Foreign key already exists for {} => {}.{}".format(
column, other_table, other_column
)
)
self.db.add_foreign_keys([(self.name, column, other_table, other_column)])
return self
def enable_counts(self):
"""
Set up triggers to update a cache of the count of rows in this table.
See :ref:`python_api_cached_table_counts` for details.
"""
sql = (
textwrap.dedent(
"""
{create_counts_table}
CREATE TRIGGER IF NOT EXISTS [{table}{counts_table}_insert] AFTER INSERT ON [{table}]
BEGIN
INSERT OR REPLACE INTO [{counts_table}]
VALUES (
{table_quoted},
COALESCE(
(SELECT count FROM [{counts_table}] WHERE [table] = {table_quoted}),
0
) + 1
);
END;
CREATE TRIGGER IF NOT EXISTS [{table}{counts_table}_delete] AFTER DELETE ON [{table}]
BEGIN
INSERT OR REPLACE INTO [{counts_table}]
VALUES (
{table_quoted},
COALESCE(
(SELECT count FROM [{counts_table}] WHERE [table] = {table_quoted}),
0
) - 1
);
END;
INSERT OR REPLACE INTO [{counts_table}] VALUES ({table_quoted}, (select count(*) from [{table}]));
"""
)
.strip()
.format(
create_counts_table=_COUNTS_TABLE_CREATE_SQL.format(
self.db._counts_table_name
),
counts_table=self.db._counts_table_name,
table=self.name,
table_quoted=self.db.quote(self.name),
)
)
with self.db.conn:
self.db.conn.executescript(sql)
self.db.use_counts_table = True
@property
def has_counts_triggers(self) -> bool:
"Does this table have triggers set up to update cached counts?"
trigger_names = {
"{table}{counts_table}_{suffix}".format(
counts_table=self.db._counts_table_name, table=self.name, suffix=suffix
)
for suffix in ["insert", "delete"]
}
return trigger_names.issubset(self.triggers_dict.keys())
def enable_fts(
self,
columns: Iterable[str],
fts_version: str = "FTS5",
create_triggers: bool = False,
tokenize: Optional[str] = None,
replace: bool = False,
):
"""
Enable SQLite full-text search against the specified columns.
See :ref:`python_api_fts` for more details.
:param columns: List of column names to include in the search index.
:param fts_version: FTS version to use - defaults to ``FTS5`` but you may want ``FTS4`` for older SQLite versions.
:param create_triggers: Should triggers be created to keep the search index up-to-date? Defaults to ``False``.
:param tokenize: Custom SQLite tokenizer to use, for example ``"porter"`` to enable Porter stemming.
:param replace: Should any existing FTS index for this table be replaced by the new one?
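For example, a sketch assuming an ``articles`` table with ``title`` and ``body`` columns:

.. code-block:: python

    db["articles"].enable_fts(["title", "body"], create_triggers=True)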
"""
create_fts_sql = (
textwrap.dedent(
"""
CREATE VIRTUAL TABLE [{table}_fts] USING {fts_version} (
{columns},{tokenize}
content=[{table}]
)
"""
)
.strip()
.format(
table=self.name,
columns=", ".join("[{}]".format(c) for c in columns),
fts_version=fts_version,
tokenize="\n tokenize='{}',".format(tokenize) if tokenize else "",
)
)
should_recreate = False
if replace and self.db["{}_fts".format(self.name)].exists():
# Does the table need to be recreated?
fts_schema = self.db["{}_fts".format(self.name)].schema
if fts_schema != create_fts_sql:
should_recreate = True
expected_triggers = {self.name + suffix for suffix in ("_ai", "_ad", "_au")}
existing_triggers = {t.name for t in self.triggers}
has_triggers = existing_triggers.issuperset(expected_triggers)
if has_triggers != create_triggers:
should_recreate = True
if not should_recreate:
# Table with correct configuration already exists
return self
if should_recreate:
self.disable_fts()
self.db.executescript(create_fts_sql)
self.populate_fts(columns)
if create_triggers:
old_cols = ", ".join("old.[{}]".format(c) for c in columns)
new_cols = ", ".join("new.[{}]".format(c) for c in columns)
triggers = (
textwrap.dedent(
"""
CREATE TRIGGER [{table}_ai] AFTER INSERT ON [{table}] BEGIN
INSERT INTO [{table}_fts] (rowid, {columns}) VALUES (new.rowid, {new_cols});
END;
CREATE TRIGGER [{table}_ad] AFTER DELETE ON [{table}] BEGIN
INSERT INTO [{table}_fts] ([{table}_fts], rowid, {columns}) VALUES('delete', old.rowid, {old_cols});
END;
CREATE TRIGGER [{table}_au] AFTER UPDATE ON [{table}] BEGIN
INSERT INTO [{table}_fts] ([{table}_fts], rowid, {columns}) VALUES('delete', old.rowid, {old_cols});
INSERT INTO [{table}_fts] (rowid, {columns}) VALUES (new.rowid, {new_cols});
END;
"""
)
.strip()
.format(
table=self.name,
columns=", ".join("[{}]".format(c) for c in columns),
old_cols=old_cols,
new_cols=new_cols,
)
)
self.db.executescript(triggers)
return self
def populate_fts(self, columns: Iterable[str]) -> "Table":
"""
Update the associated SQLite full-text search index with the latest data from the
table for the specified columns.
:param columns: Columns to populate the data for
"""
sql = (
textwrap.dedent(
"""
INSERT INTO [{table}_fts] (rowid, {columns})
SELECT rowid, {columns} FROM [{table}];
"""
)
.strip()
.format(
table=self.name, columns=", ".join("[{}]".format(c) for c in columns)
)
)
self.db.executescript(sql)
return self
def disable_fts(self) -> "Table":
"Remove any full-text search index and related triggers configured for this table."
fts_table = self.detect_fts()
if fts_table:
self.db[fts_table].drop()
# Now delete the triggers that related to that table
sql = (
textwrap.dedent(
"""
SELECT name FROM sqlite_master
WHERE type = 'trigger'
AND sql LIKE '% INSERT INTO [{}]%'
"""
)
.strip()
.format(fts_table)
)
trigger_names = []
for row in self.db.execute(sql).fetchall():
trigger_names.append(row[0])
with self.db.conn:
for trigger_name in trigger_names:
self.db.execute("DROP TRIGGER IF EXISTS [{}]".format(trigger_name))
return self
def rebuild_fts(self):
"Run the ``rebuild`` operation against the associated full-text search index table."
fts_table = self.detect_fts()
if fts_table is None:
# Assume this is itself an FTS table
fts_table = self.name
self.db.execute(
"INSERT INTO [{table}]([{table}]) VALUES('rebuild');".format(
table=fts_table
)
)
return self
def detect_fts(self) -> Optional[str]:
"Detect if table has a corresponding FTS virtual table and return it"
sql = textwrap.dedent(
"""
SELECT name FROM sqlite_master
WHERE rootpage = 0
AND (
sql LIKE :like
OR sql LIKE :like2
OR (
tbl_name = :table
AND sql LIKE '%VIRTUAL TABLE%USING FTS%'
)
)
"""
).strip()
args = {
"like": "%VIRTUAL TABLE%USING FTS%content=[{}]%".format(self.name),
"like2": '%VIRTUAL TABLE%USING FTS%content="{}"%'.format(self.name),
"table": self.name,
}
rows = self.db.execute(sql, args).fetchall()
if len(rows) == 0:
return None
else:
return rows[0][0]
def optimize(self) -> "Table":
"Run the ``optimize`` operation against the associated full-text search index table."
fts_table = self.detect_fts()
if fts_table is not None:
self.db.execute(
"""
INSERT INTO [{table}] ([{table}]) VALUES ("optimize");
""".strip().format(
table=fts_table
)
)
return self
def search_sql(
self,
columns: Optional[Iterable[str]] = None,
order_by: Optional[str] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
where: Optional[str] = None,
include_rank: bool = False,
) -> str:
""" "
Return SQL string that can be used to execute searches against this table.
:param columns: Columns to search against
:param order_by: Column or SQL expression to sort by
:param limit: SQL limit
:param offset: SQL offset
:param where: Extra SQL fragment for the WHERE clause
:param include_rank: Select the search rank column in the final query
"""
# Pick names for table and rank column that don't clash
original = "original_" if self.name == "original" else "original"
columns_sql = "*"
columns_with_prefix_sql = "[{}].*".format(original)
if columns:
columns_sql = ",\n ".join("[{}]".format(c) for c in columns)
columns_with_prefix_sql = ",\n ".join(
"[{}].[{}]".format(original, c) for c in columns
)
fts_table = self.detect_fts()
assert fts_table, "Full-text search is not configured for table '{}'".format(
self.name
)
virtual_table_using = self.db[fts_table].virtual_table_using
sql = textwrap.dedent(
"""
with {original} as (
select
rowid,
{columns}
from [{dbtable}]{where_clause}
)
select
{columns_with_prefix}
from
[{original}]
join [{fts_table}] on [{original}].rowid = [{fts_table}].rowid
where
[{fts_table}] match :query
order by
{order_by}
{limit_offset}
"""
).strip()
if virtual_table_using == "FTS5":
rank_implementation = "[{}].rank".format(fts_table)
else:
self.db.register_fts4_bm25()
rank_implementation = "rank_bm25(matchinfo([{}], 'pcnalx'))".format(
fts_table
)
if include_rank:
columns_with_prefix_sql += ",\n " + rank_implementation + " rank"
limit_offset = ""
if limit is not None:
limit_offset += " limit {}".format(limit)
if offset is not None:
limit_offset += " offset {}".format(offset)
return sql.format(
dbtable=self.name,
where_clause="\n where {}".format(where) if where else "",
original=original,
columns=columns_sql,
columns_with_prefix=columns_with_prefix_sql,
fts_table=fts_table,
order_by=order_by or rank_implementation,
limit_offset=limit_offset.strip(),
).strip()
def search(
self,
q: str,
order_by: Optional[str] = None,
columns: Optional[Iterable[str]] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
where: Optional[str] = None,
where_args: Optional[Union[Iterable, dict]] = None,
include_rank: bool = False,
quote: bool = False,
) -> Generator[dict, None, None]:
"""
Execute a search against this table using SQLite full-text search, returning a sequence of
dictionaries for each row.
:param q: Terms to search for
:param order_by: Defaults to order by rank, or specify a column here.
:param columns: List of columns to return, defaults to all columns.
:param limit: Optional integer limit for returned rows.
:param offset: Optional integer SQL offset.
:param where: Extra SQL fragment for the WHERE clause
:param where_args: Arguments to use for :param placeholders in the extra WHERE clause
:param include_rank: Select the search rank column in the final query
:param quote: Apply quoting to disable any special characters in the search query
See :ref:`python_api_fts_search`.
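A minimal sketch, assuming an ``articles`` table (an illustrative name) with full-text
search enabled on its ``title`` and ``body`` columns:
.. code-block:: python
db["articles"].enable_fts(["title", "body"])
rows = list(db["articles"].search("python", limit=5))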
"""
args = {"query": self.db.quote_fts(q) if quote else q}
if where_args and "query" in where_args:
raise ValueError(
"'query' is a reserved key and cannot be passed to where_args for .search()"
)
if where_args:
args.update(where_args)
cursor = self.db.execute(
self.search_sql(
order_by=order_by,
columns=columns,
limit=limit,
offset=offset,
where=where,
include_rank=include_rank,
),
args,
)
columns = [c[0] for c in cursor.description]
for row in cursor:
yield dict(zip(columns, row))
def value_or_default(self, key, value):
return self._defaults[key] if value is DEFAULT else value
def delete(self, pk_values: Union[list, tuple, str, int, float]) -> "Table":
"""
Delete row matching the specified primary key.
:param pk_values: A single value, or a tuple of values for tables that have a compound primary key
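A minimal sketch - the table names here are illustrative:
.. code-block:: python
db["articles"].delete(1)
# Tables with a compound primary key take a tuple of values:
db["article_tags"].delete((1, "python"))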
"""
if not isinstance(pk_values, (list, tuple)):
pk_values = [pk_values]
self.get(pk_values)
wheres = ["[{}] = ?".format(pk_name) for pk_name in self.pks]
sql = "delete from [{table}] where {wheres}".format(
table=self.name, wheres=" and ".join(wheres)
)
with self.db.conn:
self.db.execute(sql, pk_values)
return self
def delete_where(
self,
where: Optional[str] = None,
where_args: Optional[Union[Iterable, dict]] = None,
analyze: bool = False,
) -> "Table":
"""
Delete rows matching the specified where clause, or delete all rows in the table.
See :ref:`python_api_delete_where`.
:param where: SQL where fragment to use, for example ``id > ?``
:param where_args: Parameters to use with that fragment - an iterable for ``id > ?``
parameters, or a dictionary for ``id > :id``
:param analyze: Set to ``True`` to run ``ANALYZE`` after the rows have been deleted.
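A minimal sketch - the table and column names are illustrative:
.. code-block:: python
# Delete rows matching a condition
db["articles"].delete_where("created < ?", ["2020-01-01"])
# Calling it with no arguments deletes every row in the table
db["articles"].delete_where()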
"""
if not self.exists():
return self
sql = "delete from [{}]".format(self.name)
if where is not None:
sql += " where " + where
self.db.execute(sql, where_args or [])
if analyze:
self.analyze()
return self
def update(
self,
pk_values: Union[list, tuple, str, int, float],
updates: Optional[dict] = None,
alter: bool = False,
conversions: Optional[dict] = None,
) -> "Table":
"""
Execute a SQL ``UPDATE`` against the specified row.
See :ref:`python_api_update`.
:param pk_values: The primary key of an individual record - can be a tuple if the
table has a compound primary key.
:param updates: A dictionary mapping columns to their updated values.
:param alter: Set to ``True`` to add any missing columns.
:param conversions: Optional dictionary of SQL functions to apply during the update, for example
``{"mycolumn": "upper(?)"}``.
"""
updates = updates or {}
conversions = conversions or {}
if not isinstance(pk_values, (list, tuple)):
pk_values = [pk_values]
# Soundness check that the record exists (raises error if not):
self.get(pk_values)
if not updates:
return self
args = []
sets = []
wheres = []
pks = self.pks
validate_column_names(updates.keys())
for key, value in updates.items():
sets.append("[{}] = {}".format(key, conversions.get(key, "?")))
args.append(jsonify_if_needed(value))
wheres = ["[{}] = ?".format(pk_name) for pk_name in pks]
args.extend(pk_values)
sql = "update [{table}] set {sets} where {wheres}".format(
table=self.name, sets=", ".join(sets), wheres=" and ".join(wheres)
)
with self.db.conn:
try:
rowcount = self.db.execute(sql, args).rowcount
except OperationalError as e:
if alter and (" column" in e.args[0]):
# Attempt to add any missing columns, then try again
self.add_missing_columns([updates])
rowcount = self.db.execute(sql, args).rowcount
else:
raise
# TODO: Test this works (rolls back) - use better exception:
assert rowcount == 1
self.last_pk = pk_values[0] if len(pks) == 1 else pk_values
return self
def convert(
self,
columns: Union[str, List[str]],
fn: Callable,
output: Optional[str] = None,
output_type: Optional[Any] = None,
drop: bool = False,
multi: bool = False,
where: Optional[str] = None,
where_args: Optional[Union[Iterable, dict]] = None,
show_progress: bool = False,
skip_false: bool = True,
):
"""
Apply conversion function ``fn`` to every value in the specified columns.
:param columns: A single column or list of string column names to convert.
:param fn: A callable that takes a single argument, ``value``, and returns it converted.
:param output: Optional string column name to write the results to (defaults to the input column).
:param output_type: If the output column needs to be created, this is the type that will be used
for the new column.
:param drop: Should the original column be dropped once the conversion is complete?
:param multi: If ``True`` the return value of ``fn(value)`` will be expected to be a
dictionary, and new columns will be created for each key of that dictionary.
:param where: SQL fragment to use as a ``WHERE`` clause to limit the rows to which the conversion
is applied, for example ``age > ?`` or ``age > :age``.
:param where_args: List of arguments (if using ``?``) or a dictionary (if using ``:age``).
:param show_progress: Should a progress bar be displayed?
:param skip_false: Values that evaluate to false are skipped and left unchanged, unless this is set to ``False``. Defaults to ``True``.
See :ref:`python_api_convert`.
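A minimal sketch, assuming an ``articles`` table with a text ``title`` column:
.. code-block:: python
db["articles"].convert("title", lambda value: value.upper())
# Write results to a different column instead of overwriting:
db["articles"].convert("title", lambda value: value.upper(), output="title_upper")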
"""
if isinstance(columns, str):
columns = [columns]
if multi:
return self._convert_multi(
columns[0],
fn,
drop=drop,
where=where,
where_args=where_args,
show_progress=show_progress,
)
if output is not None:
assert len(columns) == 1, "output= can only be used with a single column"
if output not in self.columns_dict:
self.add_column(output, output_type or "text")
todo_count = self.count_where(where, where_args) * len(columns)
with progressbar(length=todo_count, silent=not show_progress) as bar:
def convert_value(v):
bar.update(1)
if skip_false and not v:
return v
return jsonify_if_needed(fn(v))
fn_name = fn.__name__
if fn_name == "<lambda>":
fn_name = f"lambda_{abs(hash(fn))}"
self.db.register_function(convert_value, name=fn_name)
sql = "update [{table}] set {sets}{where};".format(
table=self.name,
sets=", ".join(
[
"[{output_column}] = {fn_name}([{column}])".format(
output_column=output or column,
column=column,
fn_name=fn_name,
)
for column in columns
]
),
where=" where {}".format(where) if where is not None else "",
)
with self.db.conn:
self.db.execute(sql, where_args or [])
if drop:
self.transform(drop=columns)
return self
def _convert_multi(
self, column, fn, drop, show_progress, where=None, where_args=None
):
# First we execute the function
pk_to_values = {}
new_column_types = {}
pks = [column.name for column in self.columns if column.is_pk]
if not pks:
pks = ["rowid"]
with progressbar(
length=self.count, silent=not show_progress, label="1: Evaluating"
) as bar:
for row in self.rows_where(
select=", ".join(
"[{}]".format(column_name) for column_name in (pks + [column])
),
where=where,
where_args=where_args,
):
row_pk = tuple(row[pk] for pk in pks)
if len(row_pk) == 1:
row_pk = row_pk[0]
values = fn(row[column])
if values is not None and not isinstance(values, dict):
raise BadMultiValues(values)
if values:
for key, value in values.items():
new_column_types.setdefault(key, set()).add(type(value))
pk_to_values[row_pk] = values
bar.update(1)
# Add any new columns
columns_to_create = types_for_column_types(new_column_types)
for column_name, column_type in columns_to_create.items():
if column_name not in self.columns_dict:
self.add_column(column_name, column_type)
# Run the updates
with progressbar(
length=self.count, silent=not show_progress, label="2: Updating"
) as bar:
with self.db.conn:
for pk, updates in pk_to_values.items():
self.update(pk, updates)
bar.update(1)
if drop:
self.transform(drop=(column,))
def build_insert_queries_and_params(
self,
extracts,
chunk,
all_columns,
hash_id,
hash_id_columns,
upsert,
pk,
not_null,
conversions,
num_records_processed,
replace,
ignore,
):
"""
Given a list ``chunk`` of records that should be written to *this* table,
return a list of ``(sql, parameters)`` 2-tuples which, when executed in
order, perform the desired INSERT / UPSERT / REPLACE operation.
"""
if hash_id_columns and hash_id is None:
hash_id = "id"
extracts = resolve_extracts(extracts)
# Build a row-list ready for executemany-style flattening
values: list[list] = []
for record in chunk:
row_vals = []
for col in all_columns:
if col == hash_id:
row_vals.append(hash_record(record, hash_id_columns))
continue
val = record.get(col)
if val is None and not_null and col in not_null:
val = ""
row_vals.append(jsonify_if_needed(val))
values.append(row_vals)
columns_sql = ", ".join(f"[{c}]" for c in all_columns)
placeholder_expr = ", ".join(conversions.get(c, "?") for c in all_columns)
row_placeholders_sql = ", ".join(f"({placeholder_expr})" for _ in values)
flat_params = list(itertools.chain.from_iterable(values))
# replace=True means INSERT OR REPLACE INTO
if replace:
sql = (
f"INSERT OR REPLACE INTO [{self.name}] "
f"({columns_sql}) VALUES {row_placeholders_sql}"
)
return [(sql, flat_params)]
# If not an upsert it's an INSERT, maybe with OR IGNORE
if not upsert:
or_ignore = ""
if ignore:
or_ignore = " OR IGNORE"
sql = (
f"INSERT{or_ignore} INTO [{self.name}] "
f"({columns_sql}) VALUES {row_placeholders_sql}"
)
return [(sql, flat_params)]
# Everything from here on is for upsert=True
pk_cols = [pk] if isinstance(pk, str) else list(pk)
non_pk_cols = [c for c in all_columns if c not in pk_cols]
conflict_sql = ", ".join(f"[{c}]" for c in pk_cols)
if self.db.supports_on_conflict and not self.db.use_old_upsert:
if non_pk_cols:
# DO UPDATE
assignments = []
for c in non_pk_cols:
if c in conversions:
assignments.append(
f"[{c}] = {conversions[c].replace('?', f'excluded.[{c}]')}"
)
else:
assignments.append(f"[{c}] = excluded.[{c}]")
do_clause = "DO UPDATE SET " + ", ".join(assignments)
else:
# All columns are in the PK – nothing to update.
do_clause = "DO NOTHING"
sql = (
f"INSERT INTO [{self.name}] ({columns_sql}) "
f"VALUES {row_placeholders_sql} "
f"ON CONFLICT({conflict_sql}) {do_clause}"
)
return [(sql, flat_params)]
# At this point we need compatibility UPSERT for SQLite < 3.24.0
# (INSERT OR IGNORE + second UPDATE stage)
queries_and_params: list[tuple[str, list]] = []
insert_sql = (
f"INSERT OR IGNORE INTO [{self.name}] "
f"({columns_sql}) VALUES {row_placeholders_sql}"
)
queries_and_params.append((insert_sql, flat_params))
# If there is nothing to update we are done.
if not non_pk_cols:
return queries_and_params
# We can use UPDATE … FROM (VALUES …) on SQLite ≥ 3.33.0
# Older SQLite versions will run this as one UPDATE per row
# – which is what sqlite-utils did prior to this refactor.
alias_cols_sql = ", ".join(pk_cols + non_pk_cols)
assignments = []
for c in non_pk_cols:
if c in conversions:
assignments.append(f"[{c}] = {conversions[c].replace('?', f'v.[{c}]')}")
else:
assignments.append(f"[{c}] = v.[{c}]")
assignments_sql = ", ".join(assignments)
update_sql = (
f"UPDATE [{self.name}] AS m SET {assignments_sql} "
f"FROM (VALUES {row_placeholders_sql}) "
f"AS v({alias_cols_sql}) "
f"WHERE " + " AND ".join(f"m.[{c}] = v.[{c}]" for c in pk_cols)
)
# Parameters for the UPDATE – pk cols first then non-pk cols
update_params: list = []
for row in values:
row_dict = dict(zip(all_columns, row))
ordered = [row_dict[c] for c in pk_cols + non_pk_cols]
update_params.extend(ordered)
queries_and_params.append((update_sql, update_params))
return queries_and_params
def insert_chunk(
self,
alter,
extracts,
chunk,
all_columns,
hash_id,
hash_id_columns,
upsert,
pk,
not_null,
conversions,
num_records_processed,
replace,
ignore,
) -> Optional[sqlite3.Cursor]:
queries_and_params = self.build_insert_queries_and_params(
extracts,
chunk,
all_columns,
hash_id,
hash_id_columns,
upsert,
pk,
not_null,
conversions,
num_records_processed,
replace,
ignore,
)
result = None
with self.db.conn:
for query, params in queries_and_params:
try:
result = self.db.execute(query, params)
except OperationalError as e:
if alter and (" column" in e.args[0]):
# Attempt to add any missing columns, then try again
self.add_missing_columns(chunk)
result = self.db.execute(query, params)
elif e.args[0] == "too many SQL variables":
first_half = chunk[: len(chunk) // 2]
second_half = chunk[len(chunk) // 2 :]
self.insert_chunk(
alter,
extracts,
first_half,
all_columns,
hash_id,
hash_id_columns,
upsert,
pk,
not_null,
conversions,
num_records_processed,
replace,
ignore,
)
result = self.insert_chunk(
alter,
extracts,
second_half,
all_columns,
hash_id,
hash_id_columns,
upsert,
pk,
not_null,
conversions,
num_records_processed,
replace,
ignore,
)
else:
raise
return result
def insert(
self,
record: Dict[str, Any],
pk=DEFAULT,
foreign_keys=DEFAULT,
column_order: Optional[Union[List[str], Default]] = DEFAULT,
not_null: Optional[Union[Iterable[str], Default]] = DEFAULT,
defaults: Optional[Union[Dict[str, Any], Default]] = DEFAULT,
hash_id: Optional[Union[str, Default]] = DEFAULT,
hash_id_columns: Optional[Union[Iterable[str], Default]] = DEFAULT,
alter: Optional[Union[bool, Default]] = DEFAULT,
ignore: Optional[Union[bool, Default]] = DEFAULT,
replace: Optional[Union[bool, Default]] = DEFAULT,
extracts: Optional[Union[Dict[str, str], List[str], Default]] = DEFAULT,
conversions: Optional[Union[Dict[str, str], Default]] = DEFAULT,
columns: Optional[Union[Dict[str, Any], Default]] = DEFAULT,
strict: Optional[Union[bool, Default]] = DEFAULT,
) -> "Table":
"""
Insert a single record into the table. The table will be created with a schema that matches
the inserted record if it does not already exist, see :ref:`python_api_creating_tables`.
- ``record`` - required: a dictionary representing the record to be inserted.
The other parameters are optional, and mostly influence how the new table will be created if
that table does not exist yet.
Each of them defaults to ``DEFAULT``, which indicates that the default setting for the current
``Table`` object (specified in the table constructor) should be used.
:param record: Dictionary record to be inserted
:param pk: If creating the table, which column should be the primary key.
:param foreign_keys: See :ref:`python_api_foreign_keys`.
:param column_order: List of strings specifying a full or partial column order
to use when creating the table.
:param not_null: Set of strings specifying columns that should be ``NOT NULL``.
:param defaults: Dictionary specifying default values for specific columns.
:param hash_id: Name of a column to create and use as a primary key, where the
value of that primary key will be derived as a SHA1 hash of the other column values
in the record. ``hash_id="id"`` is a common column name used for this.
:param hash_id_columns: List of columns to use when calculating that hash - defaults to every column in the record.
:param alter: Boolean, should any missing columns be added automatically?
:param ignore: Boolean, if a record already exists with this primary key, ignore this insert.
:param replace: Boolean, if a record already exists with this primary key, replace it with this new record.
:param extracts: A list of columns to extract to other tables, or a dictionary that maps
``{column_name: other_table_name}``. See :ref:`python_api_extracts`.
:param conversions: Dictionary specifying SQL conversion functions to be applied to the data while it
is being inserted, for example ``{"name": "upper(?)"}``. See :ref:`python_api_conversions`.
:param columns: Dictionary over-riding the detected types used for the columns, for example
``{"age": int, "weight": float}``.
:param strict: Boolean, apply STRICT mode if creating the table.
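A minimal sketch - the table name and columns are illustrative:
.. code-block:: python
from sqlite_utils import Database
db = Database(memory=True)
db["articles"].insert({"id": 1, "title": "Hello"}, pk="id")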
"""
return self.insert_all(
[record],
pk=pk,
foreign_keys=foreign_keys,
column_order=column_order,
not_null=not_null,
defaults=defaults,
hash_id=hash_id,
hash_id_columns=hash_id_columns,
alter=alter,
ignore=ignore,
replace=replace,
extracts=extracts,
conversions=conversions,
columns=columns,
strict=strict,
)
def insert_all(
self,
records,
pk=DEFAULT,
foreign_keys=DEFAULT,
column_order=DEFAULT,
not_null=DEFAULT,
defaults=DEFAULT,
batch_size=DEFAULT,
hash_id=DEFAULT,
hash_id_columns=DEFAULT,
alter=DEFAULT,
ignore=DEFAULT,
replace=DEFAULT,
truncate=False,
extracts=DEFAULT,
conversions=DEFAULT,
columns=DEFAULT,
upsert=False,
analyze=False,
strict=DEFAULT,
) -> "Table":
"""
Like ``.insert()`` but takes a list of records and ensures that the table
that it creates (if table does not exist) has columns for ALL of that data.
Use ``analyze=True`` to run ``ANALYZE`` after the insert has completed.
"""
pk = self.value_or_default("pk", pk)
foreign_keys = self.value_or_default("foreign_keys", foreign_keys)
column_order = self.value_or_default("column_order", column_order)
not_null = self.value_or_default("not_null", not_null)
defaults = self.value_or_default("defaults", defaults)
batch_size = self.value_or_default("batch_size", batch_size)
hash_id = self.value_or_default("hash_id", hash_id)
hash_id_columns = self.value_or_default("hash_id_columns", hash_id_columns)
alter = self.value_or_default("alter", alter)
ignore = self.value_or_default("ignore", ignore)
replace = self.value_or_default("replace", replace)
extracts = self.value_or_default("extracts", extracts)
conversions = self.value_or_default("conversions", conversions) or {}
columns = self.value_or_default("columns", columns)
strict = self.value_or_default("strict", strict)
if hash_id_columns and hash_id is None:
hash_id = "id"
if upsert and (not pk and not hash_id):
raise PrimaryKeyRequired("upsert() requires a pk")
assert not (hash_id and pk), "Use either pk= or hash_id="
if hash_id:
pk = hash_id
assert not (
ignore and replace
), "Use either ignore=True or replace=True, not both"
all_columns = []
first = True
num_records_processed = 0
# Fix up any records with square braces in the column names
records = fix_square_braces(records)
# We can only handle a max of 999 variables in a SQL insert, so
# we need to adjust the batch_size down if we have too many cols
records = iter(records)
# Peek at first record to count its columns:
try:
first_record = next(records)
except StopIteration:
return self # It was an empty list
num_columns = len(first_record.keys())
assert (
num_columns <= SQLITE_MAX_VARS
), "Rows can have a maximum of {} columns".format(SQLITE_MAX_VARS)
batch_size = max(1, min(batch_size, SQLITE_MAX_VARS // num_columns))
self.last_rowid = None
self.last_pk = None
if truncate and self.exists():
self.db.execute("DELETE FROM [{}];".format(self.name))
result = None
for chunk in chunks(itertools.chain([first_record], records), batch_size):
chunk = list(chunk)
num_records_processed += len(chunk)
if first:
if not self.exists():
# Use the first batch to derive the column types for the table
column_types = suggest_column_types(chunk)
if extracts:
for col in extracts:
if col in column_types:
column_types[col] = (
int # This will be an integer foreign key
)
column_types.update(columns or {})
self.create(
column_types,
pk,
foreign_keys,
column_order=column_order,
not_null=not_null,
defaults=defaults,
hash_id=hash_id,
hash_id_columns=hash_id_columns,
extracts=extracts,
strict=strict,
)
all_columns_set = set()
for record in chunk:
all_columns_set.update(record.keys())
all_columns = list(sorted(all_columns_set))
if hash_id:
all_columns.insert(0, hash_id)
else:
for record in chunk:
all_columns += [
column for column in record if column not in all_columns
]
first = False
result = self.insert_chunk(
alter,
extracts,
chunk,
all_columns,
hash_id,
hash_id_columns,
upsert,
pk,
not_null,
conversions,
num_records_processed,
replace,
ignore,
)
# If we only handled a single row populate self.last_pk
if num_records_processed == 1:
# For an insert we need to use result.lastrowid
if not upsert:
self.last_rowid = result.lastrowid
if (hash_id or pk) and self.last_rowid:
# Set self.last_pk to the pk(s) for that rowid
row = list(self.rows_where("rowid = ?", [self.last_rowid]))[0]
if hash_id:
self.last_pk = row[hash_id]
elif isinstance(pk, str):
self.last_pk = row[pk]
else:
self.last_pk = tuple(row[p] for p in pk)
else:
self.last_pk = self.last_rowid
else:
# For an upsert use first_record from earlier
if hash_id:
self.last_pk = hash_record(first_record, hash_id_columns)
else:
self.last_pk = (
first_record[pk]
if isinstance(pk, str)
else tuple(first_record[p] for p in pk)
)
if analyze:
self.analyze()
return self
def upsert(
self,
record,
pk=DEFAULT,
foreign_keys=DEFAULT,
column_order=DEFAULT,
not_null=DEFAULT,
defaults=DEFAULT,
hash_id=DEFAULT,
hash_id_columns=DEFAULT,
alter=DEFAULT,
extracts=DEFAULT,
conversions=DEFAULT,
columns=DEFAULT,
strict=DEFAULT,
) -> "Table":
"""
Like ``.insert()`` but performs an ``UPSERT``, where records are inserted if they do
not exist and updated if they DO exist, based on matching against their primary key.
See :ref:`python_api_upsert`.
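A minimal sketch, assuming an ``articles`` table keyed on ``id``:
.. code-block:: python
db["articles"].upsert({"id": 1, "title": "Hello"}, pk="id")
# Running this again with the same id updates the existing row
db["articles"].upsert({"id": 1, "title": "Hello again"}, pk="id")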
"""
return self.upsert_all(
[record],
pk=pk,
foreign_keys=foreign_keys,
column_order=column_order,
not_null=not_null,
defaults=defaults,
hash_id=hash_id,
hash_id_columns=hash_id_columns,
alter=alter,
extracts=extracts,
conversions=conversions,
columns=columns,
strict=strict,
)
def upsert_all(
self,
records,
pk=DEFAULT,
foreign_keys=DEFAULT,
column_order=DEFAULT,
not_null=DEFAULT,
defaults=DEFAULT,
batch_size=DEFAULT,
hash_id=DEFAULT,
hash_id_columns=DEFAULT,
alter=DEFAULT,
extracts=DEFAULT,
conversions=DEFAULT,
columns=DEFAULT,
analyze=False,
strict=DEFAULT,
) -> "Table":
"""
Like ``.upsert()`` but can be applied to a list of records.
"""
return self.insert_all(
records,
pk=pk,
foreign_keys=foreign_keys,
column_order=column_order,
not_null=not_null,
defaults=defaults,
batch_size=batch_size,
hash_id=hash_id,
hash_id_columns=hash_id_columns,
alter=alter,
extracts=extracts,
conversions=conversions,
columns=columns,
upsert=True,
analyze=analyze,
strict=strict,
)
def add_missing_columns(self, records: Iterable[Dict[str, Any]]) -> "Table":
needed_columns = suggest_column_types(records)
current_columns = {c.lower() for c in self.columns_dict}
for col_name, col_type in needed_columns.items():
if col_name.lower() not in current_columns:
self.add_column(col_name, col_type)
return self
def lookup(
self,
lookup_values: Dict[str, Any],
extra_values: Optional[Dict[str, Any]] = None,
pk: Optional[str] = "id",
foreign_keys: Optional[ForeignKeysType] = None,
column_order: Optional[List[str]] = None,
not_null: Optional[Iterable[str]] = None,
defaults: Optional[Dict[str, Any]] = None,
extracts: Optional[Union[Dict[str, str], List[str]]] = None,
conversions: Optional[Dict[str, str]] = None,
columns: Optional[Dict[str, Any]] = None,
strict: Optional[bool] = False,
):
"""
Create or populate a lookup table with the specified values.
``db["Species"].lookup({"name": "Palm"})`` will create a table called ``Species``
(if one does not already exist) with two columns: ``id`` and ``name``. It will
set up a unique constraint on the ``name`` column to guarantee it will not
contain duplicate rows.
It will then insert a new row with the ``name`` set to ``Palm`` and return the
new integer primary key value.
An optional second argument can be provided with more ``name: value`` pairs to
be included only if the record is being created for the first time. These will
be ignored on subsequent lookup calls for records that already exist.
All other keyword arguments are passed through to ``.insert()``.
See :ref:`python_api_lookup_tables` for more details.
:param lookup_values: Dictionary specifying column names and values to use for the lookup
:param extra_values: Additional column values to be used only if creating a new record
:param strict: Boolean, apply STRICT mode if creating the table.
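A minimal sketch of the ``Species`` example described above:
.. code-block:: python
species_id = db["Species"].lookup({"name": "Palm"})
# A second call with the same values returns the same primary key
assert db["Species"].lookup({"name": "Palm"}) == species_id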
"""
assert isinstance(lookup_values, dict)
if extra_values is not None:
assert isinstance(extra_values, dict)
combined_values = dict(lookup_values)
if extra_values is not None:
combined_values.update(extra_values)
if self.exists():
self.add_missing_columns([combined_values])
unique_column_sets = [set(i.columns) for i in self.indexes]
if set(lookup_values.keys()) not in unique_column_sets:
self.create_index(lookup_values.keys(), unique=True)
wheres = ["[{}] = ?".format(column) for column in lookup_values]
rows = list(
self.rows_where(
" and ".join(wheres), [value for _, value in lookup_values.items()]
)
)
try:
return rows[0][pk]
except IndexError:
return self.insert(
combined_values,
pk=pk,
foreign_keys=foreign_keys,
column_order=column_order,
not_null=not_null,
defaults=defaults,
extracts=extracts,
conversions=conversions,
columns=columns,
strict=strict,
).last_pk
else:
pk = self.insert(
combined_values,
pk=pk,
foreign_keys=foreign_keys,
column_order=column_order,
not_null=not_null,
defaults=defaults,
extracts=extracts,
conversions=conversions,
columns=columns,
strict=strict,
).last_pk
self.create_index(lookup_values.keys(), unique=True)
return pk
def m2m(
self,
other_table: Union[str, "Table"],
record_or_iterable: Optional[
Union[Iterable[Dict[str, Any]], Dict[str, Any]]
] = None,
pk: Optional[Union[Any, Default]] = DEFAULT,
lookup: Optional[Dict[str, Any]] = None,
m2m_table: Optional[str] = None,
alter: bool = False,
):
"""
After inserting a record in a table, create one or more records in some other
table and then create many-to-many records linking the original record and the
newly created records together.
For example::
db["dogs"].insert({"id": 1, "name": "Cleo"}, pk="id").m2m(
"humans", {"id": 1, "name": "Natalie"}, pk="id"
)
See :ref:`python_api_m2m` for details.
:param other_table: The name of the table to insert the new records into.
:param record_or_iterable: A single dictionary record to insert, or a list of records.
:param pk: The primary key to use if creating ``other_table``.
:param lookup: Same dictionary as for ``.lookup()``, to create a many-to-many lookup table.
:param m2m_table: The string name to use for the many-to-many table, defaults to creating
this automatically based on the names of the two tables.
:param alter: Set to ``True`` to add any missing columns on ``other_table`` if that table
already exists.
"""
if isinstance(other_table, str):
other_table = cast(Table, self.db.table(other_table, pk=pk))
our_id = self.last_pk
if lookup is not None:
assert record_or_iterable is None, "Provide lookup= or record, not both"
else:
assert record_or_iterable is not None, "Provide lookup= or record, not both"
tables = list(sorted([self.name, other_table.name]))
columns = ["{}_id".format(t) for t in tables]
if m2m_table is not None:
m2m_table_name = m2m_table
else:
# Detect if there is a single, unambiguous option
candidates = self.db.m2m_table_candidates(self.name, other_table.name)
if len(candidates) == 1:
m2m_table_name = candidates[0]
elif len(candidates) > 1:
raise NoObviousTable(
"No single obvious m2m table for {}, {} - use m2m_table= parameter".format(
self.name, other_table.name
)
)
else:
# If not, create a new table
m2m_table_name = m2m_table or "{}_{}".format(*tables)
m2m_table_obj = self.db.table(m2m_table_name, pk=columns, foreign_keys=columns)
if lookup is None:
# If a single record was provided, wrap it in a list
if isinstance(record_or_iterable, Mapping):
records = [record_or_iterable]
else:
records = cast(List, record_or_iterable)
# Ensure each record exists in other table
for record in records:
id = other_table.insert(
cast(dict, record), pk=pk, replace=True, alter=alter
).last_pk
m2m_table_obj.insert(
{
"{}_id".format(other_table.name): id,
"{}_id".format(self.name): our_id,
},
replace=True,
)
else:
id = other_table.lookup(lookup)
m2m_table_obj.insert(
{
"{}_id".format(other_table.name): id,
"{}_id".format(self.name): our_id,
},
replace=True,
)
return self
def analyze(self):
"Run ANALYZE against this table"
self.db.analyze(self.name)
def analyze_column(
self,
column: str,
common_limit: int = 10,
value_truncate=None,
total_rows=None,
most_common: bool = True,
least_common: bool = True,
) -> "ColumnDetails":
"""
Return statistics about the specified column.
See :ref:`python_api_analyze_column`.
:param column: Column to analyze
:param common_limit: Show this many column values
:param value_truncate: Truncate display of common values to this many characters
:param total_rows: Optimization - pass the total number of rows in the table to save running a fresh ``count(*)`` query
:param most_common: If ``True``, calculate the most common values
:param least_common: If ``True``, calculate the least common values
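A minimal sketch - the ``articles`` table and ``title`` column are illustrative:
.. code-block:: python
details = db["articles"].analyze_column("title", common_limit=5)
print(details.num_distinct, details.most_common)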
"""
db = self.db
table = self.name
if total_rows is None:
total_rows = db[table].count
def truncate(value):
if value_truncate is None or isinstance(value, (float, int)):
return value
value = str(value)
if len(value) > value_truncate:
value = value[:value_truncate] + "..."
return value
num_null = db.execute(
"select count(*) from [{}] where [{}] is null".format(table, column)
).fetchone()[0]
num_blank = db.execute(
"select count(*) from [{}] where [{}] = ''".format(table, column)
).fetchone()[0]
num_distinct = db.execute(
"select count(distinct [{}]) from [{}]".format(column, table)
).fetchone()[0]
most_common_results = None
least_common_results = None
if num_distinct == 1:
value = db.execute(
"select [{}] from [{}] limit 1".format(column, table)
).fetchone()[0]
most_common_results = [(truncate(value), total_rows)]
elif num_distinct != total_rows:
if most_common:
# Optimization - if all rows are null, don't run this query
if num_null == total_rows:
most_common_results = [(None, total_rows)]
else:
most_common_results = [
(truncate(r[0]), r[1])
for r in db.execute(
"select [{}], count(*) from [{}] group by [{}] order by count(*) desc, [{}] limit {}".format(
column, table, column, column, common_limit
)
).fetchall()
]
most_common_results.sort(key=lambda p: (p[1], p[0]), reverse=True)
if least_common:
if num_distinct <= common_limit:
# No need to run the query if it will just return the results in reverse order
least_common_results = None
else:
least_common_results = [
(truncate(r[0]), r[1])
for r in db.execute(
"select [{}], count(*) from [{}] group by [{}] order by count(*), [{}] desc limit {}".format(
column, table, column, column, common_limit
)
).fetchall()
]
least_common_results.sort(key=lambda p: (p[1], p[0]))
return ColumnDetails(
self.name,
column,
total_rows,
num_null,
num_blank,
num_distinct,
most_common_results,
least_common_results,
)
def add_geometry_column(
self,
column_name: str,
geometry_type: str,
srid: int = 4326,
coord_dimension: str = "XY",
not_null: bool = False,
) -> bool:
"""
In SpatiaLite, a geometry column can only be added to an existing table.
To do so, use ``table.add_geometry_column``, passing in a geometry type.
By default, this will add a nullable column using
`SRID 4326 <https://spatialreference.org/ref/epsg/wgs-84/>`__. This can
be customized using the ``column_name``, ``srid`` and ``not_null`` arguments.
Returns ``True`` if the column was successfully added, ``False`` if not.
.. code-block:: python
from sqlite_utils.db import Database
from sqlite_utils.utils import find_spatialite
db = Database("mydb.db")
db.init_spatialite(find_spatialite())
# the table must exist before adding a geometry column
table = db["locations"].create({"name": str})
table.add_geometry_column("geometry", "POINT")
:param column_name: Name of column to add
:param geometry_type: Type of geometry column, for example ``"GEOMETRY"``, ``"POINT"`` or ``"POLYGON"``
:param srid: Integer SRID, defaults to 4326 for WGS84
:param coord_dimension: Dimensions to use, defaults to ``"XY"`` - set to ``"XYZ"`` to work in three dimensions
:param not_null: Should the column be ``NOT NULL``
"""
cursor = self.db.execute(
"SELECT AddGeometryColumn(?, ?, ?, ?, ?, ?);",
[
self.name,
column_name,
srid,
geometry_type,
coord_dimension,
int(not_null),
],
)
result = cursor.fetchone()
return result and bool(result[0])
def create_spatial_index(self, column_name) -> bool:
"""
A spatial index allows for significantly faster bounding box queries.
To create one, use ``create_spatial_index`` with the name of an existing geometry column.
Returns ``True`` if the index was successfully created, ``False`` if not. Calling this
function if an index already exists is a no-op.
.. code-block:: python
# assuming SpatiaLite is loaded, create the table, add the column
table = db["locations"].create({"name": str})
table.add_geometry_column("geometry", "POINT")
# now we can index it
table.create_spatial_index("geometry")
# the spatial index is a virtual table, which we can inspect
print(db["idx_locations_geometry"].schema)
# outputs:
# CREATE VIRTUAL TABLE "idx_locations_geometry" USING rtree(pkid, xmin, xmax, ymin, ymax)
:param column_name: Geometry column to create the spatial index against
"""
if f"idx_{self.name}_{column_name}" in self.db.table_names():
return False
cursor = self.db.execute(
"select CreateSpatialIndex(?, ?)", [self.name, column_name]
)
result = cursor.fetchone()
return result and bool(result[0])
class View(Queryable):
def exists(self):
return True
def __repr__(self) -> str:
return "<View {} ({})>".format(
self.name, ", ".join(c.name for c in self.columns)
)
def drop(self, ignore=False):
"""
Drop this view.
:param ignore: Set to ``True`` to ignore the error if the view does not exist
"""
try:
self.db.execute("DROP VIEW [{}]".format(self.name))
except sqlite3.OperationalError:
if not ignore:
raise
def enable_fts(self, *args, **kwargs):
"``enable_fts()`` is supported on tables but not on views."
raise NotImplementedError(
"enable_fts() is supported on tables but not on views"
)
def jsonify_if_needed(value):
if isinstance(value, decimal.Decimal):
return float(value)
if isinstance(value, (dict, list, tuple)):
return json.dumps(value, default=repr, ensure_ascii=False)
elif isinstance(value, (datetime.time, datetime.date, datetime.datetime)):
return value.isoformat()
elif isinstance(value, datetime.timedelta):
return str(value)
elif isinstance(value, uuid.UUID):
return str(value)
else:
return value
def resolve_extracts(
extracts: Optional[Union[Dict[str, str], List[str], Tuple[str]]],
) -> dict:
if extracts is None:
extracts = {}
if isinstance(extracts, (list, tuple)):
extracts = {item: item for item in extracts}
return extracts
def validate_column_names(columns):
# Validate no columns contain '[' or ']' - #86
for column in columns:
assert (
"[" not in column and "]" not in column
), "'[' and ']' cannot be used in column names"
def fix_square_braces(records: Iterable[Dict[str, Any]]):
for record in records:
if any("[" in key or "]" in key for key in record.keys()):
yield {
key.replace("[", "_").replace("]", "_"): value
for key, value in record.items()
}
else:
yield record
def _decode_default_value(value):
if value.startswith("'") and value.endswith("'"):
# It's a string
return value[1:-1]
if value.isdigit():
# It's an integer
return int(value)
if value.startswith("X'") and value.endswith("'"):
# It's a binary string, stored as hex
to_decode = value[2:-1]
return binascii.unhexlify(to_decode)
# If it is a string containing a floating point number:
try:
return float(value)
except ValueError:
pass
return value
</document_content>
</document>
<document index="8">
<source>./sqlite_utils/hookspecs.py</source>
<document_content>
from pluggy import HookimplMarker
from pluggy import HookspecMarker
hookspec = HookspecMarker("sqlite_utils")
hookimpl = HookimplMarker("sqlite_utils")
@hookspec
def register_commands(cli):
"""Register additional CLI commands, e.g. 'sqlite-utils mycommand ...'"""
@hookspec
def prepare_connection(conn):
"""Modify SQLite connection in some way e.g. register custom SQL functions"""
</document_content>
</document>
<document index="9">
<source>./sqlite_utils/plugins.py</source>
<document_content>
import pluggy
import sys
from . import hookspecs
pm = pluggy.PluginManager("sqlite_utils")
pm.add_hookspecs(hookspecs)
if not getattr(sys, "_called_from_test", False):
# Only load plugins if not running tests
pm.load_setuptools_entrypoints("sqlite_utils")
def get_plugins():
plugins = []
plugin_to_distinfo = dict(pm.list_plugin_distinfo())
for plugin in pm.get_plugins():
plugin_info = {
"name": plugin.__name__,
"hooks": [h.name for h in pm.get_hookcallers(plugin)],
}
distinfo = plugin_to_distinfo.get(plugin)
if distinfo:
plugin_info["version"] = distinfo.version
plugin_info["name"] = distinfo.project_name
plugins.append(plugin_info)
return plugins
</document_content>
</document>
<document index="10">
<source>./sqlite_utils/recipes.py</source>
<document_content>
from dateutil import parser
import json
IGNORE = object()
SET_NULL = object()
def parsedate(value, dayfirst=False, yearfirst=False, errors=None):
"""
Parse a date and convert it to ISO date format: yyyy-mm-dd
\b
- dayfirst=True: treat xx as the day in xx/yy/zz
- yearfirst=True: treat xx as the year in xx/yy/zz
- errors=r.IGNORE to ignore values that cannot be parsed
- errors=r.SET_NULL to set values that cannot be parsed to null
"""
try:
return (
parser.parse(value, dayfirst=dayfirst, yearfirst=yearfirst)
.date()
.isoformat()
)
except parser.ParserError:
if errors is IGNORE:
return value
elif errors is SET_NULL:
return None
else:
raise
def parsedatetime(value, dayfirst=False, yearfirst=False, errors=None):
"""
Parse a datetime and convert it to ISO datetime format: yyyy-mm-ddTHH:MM:SS
\b
- dayfirst=True: treat xx as the day in xx/yy/zz
- yearfirst=True: treat xx as the year in xx/yy/zz
- errors=r.IGNORE to ignore values that cannot be parsed
- errors=r.SET_NULL to set values that cannot be parsed to null
"""
try:
return parser.parse(value, dayfirst=dayfirst, yearfirst=yearfirst).isoformat()
except parser.ParserError:
if errors is IGNORE:
return value
elif errors is SET_NULL:
return None
else:
raise
def jsonsplit(value, delimiter=",", type=str):
"""
Convert a string like a,b,c into a JSON array ["a", "b", "c"]
"""
return json.dumps([type(s.strip()) for s in value.split(delimiter)])
</document_content>
</document>
<document index="11">
<source>./sqlite_utils/utils.py</source>
<document_content>
import base64
import contextlib
import csv
import enum
import hashlib
import io
import itertools
import json
import os
import sys
from . import recipes
from typing import Dict, cast, BinaryIO, Iterable, Optional, Tuple, Type
import click
try:
import pysqlite3 as sqlite3 # noqa: F401
from pysqlite3 import dbapi2 # noqa: F401
OperationalError = dbapi2.OperationalError
except ImportError:
try:
import sqlean as sqlite3 # noqa: F401
from sqlean import dbapi2 # noqa: F401
OperationalError = dbapi2.OperationalError
except ImportError:
import sqlite3 # noqa: F401
from sqlite3 import dbapi2 # noqa: F401
OperationalError = dbapi2.OperationalError
SPATIALITE_PATHS = (
"/usr/lib/x86_64-linux-gnu/mod_spatialite.so",
"/usr/lib/aarch64-linux-gnu/mod_spatialite.so",
"/usr/local/lib/mod_spatialite.dylib",
"/usr/local/lib/mod_spatialite.so",
"/opt/homebrew/lib/mod_spatialite.dylib",
)
# Mainly so we can restore it if needed in the tests:
ORIGINAL_CSV_FIELD_SIZE_LIMIT = csv.field_size_limit()
def maximize_csv_field_size_limit():
"""
Increase the CSV field size limit to the maximum possible.
"""
# https://stackoverflow.com/a/15063941
field_size_limit = sys.maxsize
while True:
try:
csv.field_size_limit(field_size_limit)
break
except OverflowError:
field_size_limit = int(field_size_limit / 10)
def find_spatialite() -> Optional[str]:
"""
The ``find_spatialite()`` function searches for the `SpatiaLite <https://www.gaia-gis.it/fossil/libspatialite/index>`__
SQLite extension in some common places. It returns a string path to the location, or ``None`` if SpatiaLite was not found.
You can use it in code like this:
.. code-block:: python
from sqlite_utils import Database
from sqlite_utils.utils import find_spatialite
db = Database("mydb.db")
spatialite = find_spatialite()
if spatialite:
db.conn.enable_load_extension(True)
db.conn.load_extension(spatialite)
# or use with db.init_spatialite like this
db.init_spatialite(find_spatialite())
"""
for path in SPATIALITE_PATHS:
if os.path.exists(path):
return path
return None
def suggest_column_types(records):
all_column_types = {}
for record in records:
for key, value in record.items():
all_column_types.setdefault(key, set()).add(type(value))
return types_for_column_types(all_column_types)
def types_for_column_types(all_column_types):
column_types = {}
for key, types in all_column_types.items():
# Ignore null values if at least one other type present:
if len(types) > 1:
types.discard(None.__class__)
if {None.__class__} == types:
t = str
elif len(types) == 1:
t = list(types)[0]
# But if it's a subclass of list / tuple / dict, use str
# instead as we will be storing it as JSON in the table
for superclass in (list, tuple, dict):
if issubclass(t, superclass):
t = str
elif {int, bool}.issuperset(types):
t = int
elif {int, float, bool}.issuperset(types):
t = float
elif {bytes, str}.issuperset(types):
t = bytes
else:
t = str
column_types[key] = t
return column_types
def column_affinity(column_type):
# Implementation of SQLite affinity rules from
# https://www.sqlite.org/datatype3.html#determination_of_column_affinity
assert isinstance(column_type, str)
column_type = column_type.upper().strip()
if column_type == "":
return str # We differ from spec, which says it should be BLOB
if "INT" in column_type:
return int
if "CHAR" in column_type or "CLOB" in column_type or "TEXT" in column_type:
return str
if "BLOB" in column_type:
return bytes
if "REAL" in column_type or "FLOA" in column_type or "DOUB" in column_type:
return float
# Default is 'NUMERIC', which we currently also treat as float
return float
def decode_base64_values(doc):
# Looks for '{"$base64": true..., "encoded": ...}' values and decodes them
to_fix = [
k
for k in doc
if isinstance(doc[k], dict)
and doc[k].get("$base64") is True
and "encoded" in doc[k]
]
if not to_fix:
return doc
return dict(doc, **{k: base64.b64decode(doc[k]["encoded"]) for k in to_fix})
class UpdateWrapper:
def __init__(self, wrapped, update):
self._wrapped = wrapped
self._update = update
def __iter__(self):
for line in self._wrapped:
self._update(len(line))
yield line
def read(self, size=-1):
data = self._wrapped.read(size)
self._update(len(data))
return data
@contextlib.contextmanager
def file_progress(file, silent=False, **kwargs):
if silent:
yield file
return
# file.fileno() throws an exception in our test suite
try:
fileno = file.fileno()
except io.UnsupportedOperation:
yield file
return
if fileno == 0: # 0 means stdin
yield file
else:
file_length = os.path.getsize(file.name)
with click.progressbar(length=file_length, **kwargs) as bar:
yield UpdateWrapper(file, bar.update)
class Format(enum.Enum):
CSV = 1
TSV = 2
JSON = 3
NL = 4
class RowsFromFileError(Exception):
pass
class RowsFromFileBadJSON(RowsFromFileError):
pass
class RowError(Exception):
pass
def _extra_key_strategy(
reader: Iterable[dict],
ignore_extras: Optional[bool] = False,
extras_key: Optional[str] = None,
) -> Iterable[dict]:
# Logic for handling CSV rows with more values than there are headings
for row in reader:
# DictReader adds a 'None' key with extra row values
if None not in row:
yield row
elif ignore_extras:
# ignoring row.pop(none) because of this issue:
# https://github.com/simonw/sqlite-utils/issues/440#issuecomment-1155358637
row.pop(None) # type: ignore
yield row
elif not extras_key:
extras = row.pop(None) # type: ignore
raise RowError(
"Row {} contained these extra values: {}".format(row, extras)
)
else:
row[extras_key] = row.pop(None) # type: ignore
yield row
def rows_from_file(
fp: BinaryIO,
format: Optional[Format] = None,
dialect: Optional[Type[csv.Dialect]] = None,
encoding: Optional[str] = None,
ignore_extras: Optional[bool] = False,
extras_key: Optional[str] = None,
) -> Tuple[Iterable[dict], Format]:
"""
Load a sequence of dictionaries from a file-like object containing one of four different formats.
.. code-block:: python
from sqlite_utils.utils import rows_from_file
import io
rows, format = rows_from_file(io.BytesIO(b"id,name\\n1,Cleo"))
print(list(rows), format)
# Outputs [{'id': '1', 'name': 'Cleo'}] Format.CSV
This defaults to attempting to automatically detect the format of the data, or you can pass in an
explicit format using the format= option.
Returns a tuple of ``(rows_generator, format_used)`` where ``rows_generator`` can be iterated over
to return dictionaries, while ``format_used`` is a value from the ``sqlite_utils.utils.Format`` enum:
.. code-block:: python
class Format(enum.Enum):
CSV = 1
TSV = 2
JSON = 3
NL = 4
If a CSV or TSV file includes rows with more fields than are declared in the header a
``sqlite_utils.utils.RowError`` exception will be raised when you loop over the generator.
You can instead ignore the extra data by passing ``ignore_extras=True``.
Or pass ``extras_key="rest"`` to put those additional values in a list in a key called ``rest``.
:param fp: a file-like object containing binary data
:param format: the format to use - omit this to detect the format
:param dialect: the CSV dialect to use - omit this to detect the dialect
:param encoding: the character encoding to use when reading CSV/TSV data
:param ignore_extras: ignore any extra fields on rows
:param extras_key: put any extra fields in a list with this key
"""
if ignore_extras and extras_key:
raise ValueError("Cannot use ignore_extras= and extras_key= together")
if format == Format.JSON:
decoded = json.load(fp)
if isinstance(decoded, dict):
decoded = [decoded]
if not isinstance(decoded, list):
raise RowsFromFileBadJSON("JSON must be a list or a dictionary")
return decoded, Format.JSON
elif format == Format.NL:
return (json.loads(line) for line in fp if line.strip()), Format.NL
elif format == Format.CSV:
use_encoding: str = encoding or "utf-8-sig"
decoded_fp = io.TextIOWrapper(fp, encoding=use_encoding)
if dialect is not None:
reader = csv.DictReader(decoded_fp, dialect=dialect)
else:
reader = csv.DictReader(decoded_fp)
return _extra_key_strategy(reader, ignore_extras, extras_key), Format.CSV
elif format == Format.TSV:
rows = rows_from_file(
fp, format=Format.CSV, dialect=csv.excel_tab, encoding=encoding
)[0]
return _extra_key_strategy(rows, ignore_extras, extras_key), Format.TSV
elif format is None:
# Detect the format, then call this recursively
buffered = io.BufferedReader(cast(io.RawIOBase, fp), buffer_size=4096)
try:
first_bytes = buffered.peek(2048).strip()
except AttributeError:
# Likely the user passed a TextIO when this needs a BytesIO
raise TypeError(
"rows_from_file() requires a file-like object that supports peek(), such as io.BytesIO"
)
if first_bytes.startswith(b"[") or first_bytes.startswith(b"{"):
# TODO: Detect newline-JSON
return rows_from_file(buffered, format=Format.JSON)
else:
dialect = csv.Sniffer().sniff(
first_bytes.decode(encoding or "utf-8-sig", "ignore")
)
rows, _ = rows_from_file(
buffered, format=Format.CSV, dialect=dialect, encoding=encoding
)
# Make sure we return the format we detected
format = Format.TSV if dialect.delimiter == "\t" else Format.CSV
return _extra_key_strategy(rows, ignore_extras, extras_key), format
else:
raise RowsFromFileError("Bad format")
class TypeTracker:
"""
Wrap an iterator of dictionaries and keep track of which SQLite column
types are the most likely fit for each of their keys.
Example usage:
.. code-block:: python
from sqlite_utils.utils import TypeTracker
import sqlite_utils
db = sqlite_utils.Database(memory=True)
tracker = TypeTracker()
rows = [{"id": "1", "name": "Cleo", "id": "2", "name": "Cardi"}]
db["creatures"].insert_all(tracker.wrap(rows))
print(tracker.types)
# Outputs {'id': 'integer', 'name': 'text'}
db["creatures"].transform(types=tracker.types)
"""
def __init__(self):
self.trackers = {}
def wrap(self, iterator: Iterable[dict]) -> Iterable[dict]:
"""
Use this to loop through an existing iterator, tracking the column types
as part of the iteration.
:param iterator: The iterator to wrap
"""
for row in iterator:
for key, value in row.items():
tracker = self.trackers.setdefault(key, ValueTracker())
tracker.evaluate(value)
yield row
@property
def types(self) -> Dict[str, str]:
"""
A dictionary mapping column names to their detected types. This can be passed
to the ``db[table_name].transform(types=tracker.types)`` method.
"""
return {key: tracker.guessed_type for key, tracker in self.trackers.items()}
class ValueTracker:
def __init__(self):
self.couldbe = {key: getattr(self, "test_" + key) for key in self.get_tests()}
@classmethod
def get_tests(cls):
return [
key.split("test_")[-1]
for key in cls.__dict__.keys()
if key.startswith("test_")
]
def test_integer(self, value):
try:
int(value)
return True
except (ValueError, TypeError):
return False
def test_float(self, value):
try:
float(value)
return True
except (ValueError, TypeError):
return False
def __repr__(self) -> str:
return self.guessed_type + ": possibilities = " + repr(self.couldbe)
@property
def guessed_type(self):
options = set(self.couldbe.keys())
# Return based on precedence
for key in self.get_tests():
if key in options:
return key
return "text"
def evaluate(self, value):
if not value or not self.couldbe:
return
not_these = []
for name, test in self.couldbe.items():
if not test(value):
not_these.append(name)
for key in not_these:
del self.couldbe[key]
class NullProgressBar:
def __init__(self, *args):
self.args = args
def __iter__(self):
yield from self.args[0]
def update(self, value):
pass
@contextlib.contextmanager
def progressbar(*args, **kwargs):
silent = kwargs.pop("silent")
if silent:
yield NullProgressBar(*args)
else:
with click.progressbar(*args, **kwargs) as bar:
yield bar
def _compile_code(code, imports, variable="value"):
globals = {"r": recipes, "recipes": recipes}
# If user defined a convert() function, return that
try:
exec(code, globals)
return globals["convert"]
except (AttributeError, SyntaxError, NameError, KeyError, TypeError):
pass
# Try compiling their code as a function instead
body_variants = [code]
# If single line and no 'return', try adding the return
if "\n" not in code and not code.strip().startswith("return "):
body_variants.insert(0, "return {}".format(code))
code_o = None
for variant in body_variants:
new_code = ["def fn({}):".format(variable)]
for line in variant.split("\n"):
new_code.append(" {}".format(line))
try:
code_o = compile("\n".join(new_code), "<string>", "exec")
break
except SyntaxError:
# Try another variant, e.g. for 'return row["column"] = 1'
continue
if code_o is None:
raise SyntaxError("Could not compile code")
for import_ in imports:
globals[import_.split(".")[0]] = __import__(import_)
exec(code_o, globals)
return globals["fn"]
def chunks(sequence: Iterable, size: int) -> Iterable[Iterable]:
"""
Iterate over chunks of the sequence of the given size.
:param sequence: Any Python iterator
:param size: The size of each chunk
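Example usage - each chunk should be fully consumed before moving on to the
next one, because every chunk shares the same underlying iterator::
chunked = [list(chunk) for chunk in chunks(range(10), 4)]
# chunked == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]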
"""
iterator = iter(sequence)
for item in iterator:
yield itertools.chain([item], itertools.islice(iterator, size - 1))
def hash_record(record: Dict, keys: Optional[Iterable[str]] = None):
"""
``record`` should be a Python dictionary. Returns a sha1 hash of the
keys and values in that record.
If ``keys=`` is provided, uses just those keys to generate the hash.
Example usage::
from sqlite_utils.utils import hash_record
hashed = hash_record({"name": "Cleo", "twitter": "CleoPaws"})
# Or with the keys= option:
hashed = hash_record(
{"name": "Cleo", "twitter": "CleoPaws", "age": 7},
keys=("name", "twitter")
)
:param record: Record to generate a hash for
:param keys: Subset of keys to use for that hash
"""
to_hash = record
if keys is not None:
to_hash = {key: record[key] for key in keys}
return hashlib.sha1(
json.dumps(to_hash, separators=(",", ":"), sort_keys=True, default=repr).encode(
"utf8"
)
).hexdigest()
def _flatten(d):
for key, value in d.items():
if isinstance(value, dict):
for key2, value2 in _flatten(value):
yield key + "_" + key2, value2
else:
yield key, value
def flatten(row: dict) -> dict:
"""
Turn a nested dict e.g. ``{"a": {"b": 1}}`` into a flat dict: ``{"a_b": 1}``
:param row: A Python dictionary, optionally with nested dictionaries
"""
return dict(_flatten(row))
</document_content>
</document>
<document index="12">
<source>./tests/__init__.py</source>
<document_content>
</document_content>
</document>
<document index="13">
<source>./tests/conftest.py</source>
<document_content>
from sqlite_utils import Database
from sqlite_utils.utils import sqlite3
import pytest
CREATE_TABLES = """
create table Gosh (c1 text, c2 text, c3 text);
create table Gosh2 (c1 text, c2 text, c3 text);
"""
def pytest_configure(config):
import sys
sys._called_from_test = True
@pytest.fixture
def fresh_db():
return Database(memory=True)
@pytest.fixture
def existing_db():
database = Database(memory=True)
database.executescript(
"""
CREATE TABLE foo (text TEXT);
INSERT INTO foo (text) values ("one");
INSERT INTO foo (text) values ("two");
INSERT INTO foo (text) values ("three");
"""
)
return database
@pytest.fixture
def db_path(tmpdir):
path = str(tmpdir / "test.db")
db = sqlite3.connect(path)
db.executescript(CREATE_TABLES)
return path
</document_content>
</document>
<document index="14">
<source>./tests/test_analyze.py</source>
<document_content>
import pytest
@pytest.fixture
def db(fresh_db):
fresh_db["one_index"].insert({"id": 1, "name": "Cleo"}, pk="id")
fresh_db["one_index"].create_index(["name"])
fresh_db["two_indexes"].insert({"id": 1, "name": "Cleo", "species": "dog"}, pk="id")
fresh_db["two_indexes"].create_index(["name"])
fresh_db["two_indexes"].create_index(["species"])
return fresh_db
def test_analyze_whole_database(db):
assert set(db.table_names()) == {"one_index", "two_indexes"}
db.analyze()
assert set(db.table_names()).issuperset(
{"one_index", "two_indexes", "sqlite_stat1"}
)
assert list(db["sqlite_stat1"].rows) == [
{"tbl": "two_indexes", "idx": "idx_two_indexes_species", "stat": "1 1"},
{"tbl": "two_indexes", "idx": "idx_two_indexes_name", "stat": "1 1"},
{"tbl": "one_index", "idx": "idx_one_index_name", "stat": "1 1"},
]
@pytest.mark.parametrize("method", ("db_method_with_name", "table_method"))
def test_analyze_one_table(db, method):
assert set(db.table_names()).issuperset({"one_index", "two_indexes"})
if method == "db_method_with_name":
db.analyze("one_index")
elif method == "table_method":
db["one_index"].analyze()
assert set(db.table_names()).issuperset(
{"one_index", "two_indexes", "sqlite_stat1"}
)
assert list(db["sqlite_stat1"].rows) == [
{"tbl": "one_index", "idx": "idx_one_index_name", "stat": "1 1"}
]
def test_analyze_index_by_name(db):
assert set(db.table_names()) == {"one_index", "two_indexes"}
db.analyze("idx_two_indexes_species")
assert set(db.table_names()).issuperset(
{"one_index", "two_indexes", "sqlite_stat1"}
)
assert list(db["sqlite_stat1"].rows) == [
{"tbl": "two_indexes", "idx": "idx_two_indexes_species", "stat": "1 1"},
]
</document_content>
</document>
<document index="15">
<source>./tests/test_analyze_tables.py</source>
<document_content>
from sqlite_utils.db import Database, ColumnDetails
from sqlite_utils import cli
from click.testing import CliRunner
import pytest
import sqlite3
@pytest.fixture
def db_to_analyze(fresh_db):
stuff = fresh_db["stuff"]
stuff.insert_all(
[
{"id": 1, "owner": "Terryterryterry", "size": 5},
{"id": 2, "owner": "Joan", "size": 4},
{"id": 3, "owner": "Kumar", "size": 5},
{"id": 4, "owner": "Anne", "size": 5},
{"id": 5, "owner": "Terryterryterry", "size": 5},
{"id": 6, "owner": "Joan", "size": 4},
{"id": 7, "owner": "Kumar", "size": 5},
{"id": 8, "owner": "Joan", "size": 4},
],
pk="id",
)
return fresh_db
@pytest.fixture
def big_db_to_analyze_path(tmpdir):
path = str(tmpdir / "test.db")
db = Database(path)
categories = {
"A": 40,
"B": 30,
"C": 20,
"D": 10,
}
to_insert = []
for category, count in categories.items():
for _ in range(count):
to_insert.append(
{
"category": category,
"all_null": None,
}
)
db["stuff"].insert_all(to_insert)
return path
@pytest.mark.parametrize(
"column,extra_kwargs,expected",
[
(
"id",
{},
ColumnDetails(
table="stuff",
column="id",
total_rows=8,
num_null=0,
num_blank=0,
num_distinct=8,
most_common=None,
least_common=None,
),
),
(
"owner",
{},
ColumnDetails(
table="stuff",
column="owner",
total_rows=8,
num_null=0,
num_blank=0,
num_distinct=4,
most_common=[("Joan", 3), ("Kumar", 2)],
least_common=[("Anne", 1), ("Terry...", 2)],
),
),
(
"size",
{},
ColumnDetails(
table="stuff",
column="size",
total_rows=8,
num_null=0,
num_blank=0,
num_distinct=2,
most_common=[(5, 5), (4, 3)],
least_common=None,
),
),
(
"owner",
{"most_common": False},
ColumnDetails(
table="stuff",
column="owner",
total_rows=8,
num_null=0,
num_blank=0,
num_distinct=4,
most_common=None,
least_common=[("Anne", 1), ("Terry...", 2)],
),
),
(
"owner",
{"least_common": False},
ColumnDetails(
table="stuff",
column="owner",
total_rows=8,
num_null=0,
num_blank=0,
num_distinct=4,
most_common=[("Joan", 3), ("Kumar", 2)],
least_common=None,
),
),
],
)
def test_analyze_column(db_to_analyze, column, extra_kwargs, expected):
assert (
db_to_analyze["stuff"].analyze_column(
column, common_limit=2, value_truncate=5, **extra_kwargs
)
== expected
)
@pytest.fixture
def db_to_analyze_path(db_to_analyze, tmpdir):
path = str(tmpdir / "test.db")
db = sqlite3.connect(path)
sql = "\n".join(db_to_analyze.iterdump())
db.executescript(sql)
return path
def test_analyze_table(db_to_analyze_path):
result = CliRunner().invoke(cli.cli, ["analyze-tables", db_to_analyze_path])
assert (
result.output.strip()
== (
"""
stuff.id: (1/3)
Total rows: 8
Null rows: 0
Blank rows: 0
Distinct values: 8
stuff.owner: (2/3)
Total rows: 8
Null rows: 0
Blank rows: 0
Distinct values: 4
Most common:
3: Joan
2: Terryterryterry
2: Kumar
1: Anne
stuff.size: (3/3)
Total rows: 8
Null rows: 0
Blank rows: 0
Distinct values: 2
Most common:
5: 5
3: 4"""
).strip()
)
def test_analyze_table_save(db_to_analyze_path):
result = CliRunner().invoke(
cli.cli, ["analyze-tables", db_to_analyze_path, "--save"]
)
assert result.exit_code == 0
rows = list(Database(db_to_analyze_path)["_analyze_tables_"].rows)
assert rows == [
{
"table": "stuff",
"column": "id",
"total_rows": 8,
"num_null": 0,
"num_blank": 0,
"num_distinct": 8,
"most_common": None,
"least_common": None,
},
{
"table": "stuff",
"column": "owner",
"total_rows": 8,
"num_null": 0,
"num_blank": 0,
"num_distinct": 4,
"most_common": '[["Joan", 3], ["Terryterryterry", 2], ["Kumar", 2], ["Anne", 1]]',
"least_common": None,
},
{
"table": "stuff",
"column": "size",
"total_rows": 8,
"num_null": 0,
"num_blank": 0,
"num_distinct": 2,
"most_common": "[[5, 5], [4, 3]]",
"least_common": None,
},
]
@pytest.mark.parametrize(
"no_most,no_least",
(
(False, False),
(True, False),
(False, True),
(True, True),
),
)
def test_analyze_table_save_no_most_no_least_options(
no_most, no_least, big_db_to_analyze_path
):
args = [
"analyze-tables",
big_db_to_analyze_path,
"--save",
"--common-limit",
"2",
"--column",
"category",
]
if no_most:
args.append("--no-most")
if no_least:
args.append("--no-least")
result = CliRunner().invoke(cli.cli, args)
assert result.exit_code == 0
rows = list(Database(big_db_to_analyze_path)["_analyze_tables_"].rows)
expected = {
"table": "stuff",
"column": "category",
"total_rows": 100,
"num_null": 0,
"num_blank": 0,
"num_distinct": 4,
"most_common": None,
"least_common": None,
}
if not no_most:
expected["most_common"] = '[["A", 40], ["B", 30]]'
if not no_least:
expected["least_common"] = '[["D", 10], ["C", 20]]'
assert rows == [expected]
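# For reference, a minimal sketch (standard library only) of reading back the
# _analyze_tables_ rows that --save writes; most_common/least_common are stored
# as JSON-encoded strings, as the assertions above show. "data.db" is a
# hypothetical path:
#
#   import json, sqlite3
#   conn = sqlite3.connect("data.db")
#   for table, column, most_common in conn.execute(
#       "select [table], [column], most_common from _analyze_tables_"
#   ):
#       pairs = json.loads(most_common) if most_common else []
#       print(table, column, pairs)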
def test_analyze_table_column_all_nulls(big_db_to_analyze_path):
result = CliRunner().invoke(
cli.cli,
["analyze-tables", big_db_to_analyze_path, "stuff", "--column", "all_null"],
)
assert result.exit_code == 0
assert result.output == (
"stuff.all_null: (1/1)\n\n Total rows: 100\n"
" Null rows: 100\n"
" Blank rows: 0\n"
"\n"
" Distinct values: 0\n\n"
)
@pytest.mark.parametrize(
"args,expected_error",
(
(["-c", "bad_column"], "These columns were not found: bad_column\n"),
(["one", "-c", "age"], "These columns were not found: age\n"),
(["two", "-c", "age"], None),
(
["one", "-c", "age", "--column", "bad"],
"These columns were not found: age, bad\n",
),
),
)
def test_analyze_table_validate_columns(tmpdir, args, expected_error):
path = str(tmpdir / "test_validate_columns.db")
db = Database(path)
db["one"].insert(
{
"id": 1,
"name": "one",
}
)
db["two"].insert(
{
"id": 1,
"age": 5,
}
)
result = CliRunner().invoke(
cli.cli,
["analyze-tables", path] + args,
catch_exceptions=False,
)
assert result.exit_code == (1 if expected_error else 0)
if expected_error:
assert expected_error in result.output
</document_content>
</document>
<document index="16">
<source>./tests/test_attach.py</source>
<document_content>
from sqlite_utils import Database
def test_attach(tmpdir):
foo_path = str(tmpdir / "foo.db")
bar_path = str(tmpdir / "bar.db")
db = Database(foo_path)
with db.conn:
db["foo"].insert({"id": 1, "text": "foo"})
db2 = Database(bar_path)
with db2.conn:
db2["bar"].insert({"id": 1, "text": "bar"})
db.attach("bar", bar_path)
assert db.execute(
"select * from foo union all select * from bar.bar"
).fetchall() == [(1, "foo"), (1, "bar")]
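# For reference, Database.attach() corresponds to SQLite's ATTACH DATABASE
# statement, which exposes a second file under a schema alias. A minimal
# standard-library sketch using the same foo.db / bar.db layout as above:
#
#   import sqlite3
#   conn = sqlite3.connect("foo.db")
#   conn.execute("attach database 'bar.db' as bar")
#   conn.execute("select * from foo union all select * from bar.bar").fetchall()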
</document_content>
</document>
<document index="17">
<source>./tests/test_cli.py</source>
<document_content>
from sqlite_utils import cli, Database
from sqlite_utils.db import Index, ForeignKey
from click.testing import CliRunner
from pathlib import Path
import subprocess
import sys
from unittest import mock
import json
import os
import pytest
import textwrap
def write_json(file_path, data):
with open(file_path, "w") as fp:
json.dump(data, fp)
def _supports_pragma_function_list():
db = Database(memory=True)
try:
db.execute("select * from pragma_function_list()")
except Exception:
return False
return True
def _has_compiled_ext():
for ext in ["dylib", "so", "dll"]:
path = Path(__file__).parent / f"ext.{ext}"
if path.is_file():
return True
return False
COMPILED_EXTENSION_PATH = str(Path(__file__).parent / "ext")
@pytest.mark.parametrize(
"options",
(
["-h"],
["--help"],
["insert", "-h"],
["insert", "--help"],
),
)
def test_help(options):
result = CliRunner().invoke(cli.cli, options)
assert result.exit_code == 0
assert result.output.startswith("Usage: ")
assert "-h, --help" in result.output
def test_tables(db_path):
result = CliRunner().invoke(cli.cli, ["tables", db_path], catch_exceptions=False)
assert '[{"table": "Gosh"},\n {"table": "Gosh2"}]' == result.output.strip()
def test_views(db_path):
Database(db_path).create_view("hello", "select sqlite_version()")
result = CliRunner().invoke(cli.cli, ["views", db_path, "--table", "--schema"])
assert (
"view schema\n"
"------ --------------------------------------------\n"
"hello CREATE VIEW hello AS select sqlite_version()"
) == result.output.strip()
def test_tables_fts4(db_path):
Database(db_path)["Gosh"].enable_fts(["c2"], fts_version="FTS4")
result = CliRunner().invoke(cli.cli, ["tables", "--fts4", db_path])
assert '[{"table": "Gosh_fts"}]' == result.output.strip()
def test_tables_fts5(db_path):
Database(db_path)["Gosh"].enable_fts(["c2"], fts_version="FTS5")
result = CliRunner().invoke(cli.cli, ["tables", "--fts5", db_path])
assert '[{"table": "Gosh_fts"}]' == result.output.strip()
def test_tables_counts_and_columns(db_path):
db = Database(db_path)
with db.conn:
db["lots"].insert_all([{"id": i, "age": i + 1} for i in range(30)])
result = CliRunner().invoke(cli.cli, ["tables", "--counts", "--columns", db_path])
assert (
'[{"table": "Gosh", "count": 0, "columns": ["c1", "c2", "c3"]},\n'
' {"table": "Gosh2", "count": 0, "columns": ["c1", "c2", "c3"]},\n'
' {"table": "lots", "count": 30, "columns": ["id", "age"]}]'
) == result.output.strip()
@pytest.mark.parametrize(
"format,expected",
[
(
"--csv",
(
"table,count,columns\n"
'Gosh,0,"c1\n'
"c2\n"
'c3"\n'
'Gosh2,0,"c1\n'
"c2\n"
'c3"\n'
'lots,30,"id\n'
'age"'
),
),
(
"--tsv",
"table\tcount\tcolumns\nGosh\t0\t['c1', 'c2', 'c3']\nGosh2\t0\t['c1', 'c2', 'c3']\nlots\t30\t['id', 'age']",
),
],
)
def test_tables_counts_and_columns_csv(db_path, format, expected):
db = Database(db_path)
with db.conn:
db["lots"].insert_all([{"id": i, "age": i + 1} for i in range(30)])
result = CliRunner().invoke(
cli.cli, ["tables", "--counts", "--columns", format, db_path]
)
assert result.output.strip().replace("\r", "") == expected
def test_tables_schema(db_path):
db = Database(db_path)
with db.conn:
db["lots"].insert_all([{"id": i, "age": i + 1} for i in range(30)])
result = CliRunner().invoke(cli.cli, ["tables", "--schema", db_path])
assert (
'[{"table": "Gosh", "schema": "CREATE TABLE Gosh (c1 text, c2 text, c3 text)"},\n'
' {"table": "Gosh2", "schema": "CREATE TABLE Gosh2 (c1 text, c2 text, c3 text)"},\n'
' {"table": "lots", "schema": "CREATE TABLE [lots] (\\n [id] INTEGER,\\n [age] INTEGER\\n)"}]'
) == result.output.strip()
@pytest.mark.parametrize(
"options,expected",
[
(
["--fmt", "simple"],
(
"c1 c2 c3\n"
"----- ----- ----------\n"
"verb0 noun0 adjective0\n"
"verb1 noun1 adjective1\n"
"verb2 noun2 adjective2\n"
"verb3 noun3 adjective3"
),
),
(
["-t"],
(
"c1 c2 c3\n"
"----- ----- ----------\n"
"verb0 noun0 adjective0\n"
"verb1 noun1 adjective1\n"
"verb2 noun2 adjective2\n"
"verb3 noun3 adjective3"
),
),
(
["--fmt", "rst"],
(
"===== ===== ==========\n"
"c1 c2 c3\n"
"===== ===== ==========\n"
"verb0 noun0 adjective0\n"
"verb1 noun1 adjective1\n"
"verb2 noun2 adjective2\n"
"verb3 noun3 adjective3\n"
"===== ===== =========="
),
),
],
)
def test_output_table(db_path, options, expected):
db = Database(db_path)
with db.conn:
db["rows"].insert_all(
[
{
"c1": "verb{}".format(i),
"c2": "noun{}".format(i),
"c3": "adjective{}".format(i),
}
for i in range(4)
]
)
result = CliRunner().invoke(cli.cli, ["rows", db_path, "rows"] + options)
assert result.exit_code == 0
assert expected == result.output.strip()
def test_create_index(db_path):
db = Database(db_path)
assert [] == db["Gosh"].indexes
result = CliRunner().invoke(cli.cli, ["create-index", db_path, "Gosh", "c1"])
assert result.exit_code == 0
assert [
Index(
seq=0, name="idx_Gosh_c1", unique=0, origin="c", partial=0, columns=["c1"]
)
] == db["Gosh"].indexes
# Try with a custom name
result = CliRunner().invoke(
cli.cli, ["create-index", db_path, "Gosh", "c2", "--name", "blah"]
)
assert result.exit_code == 0
assert [
Index(seq=0, name="blah", unique=0, origin="c", partial=0, columns=["c2"]),
Index(
seq=1, name="idx_Gosh_c1", unique=0, origin="c", partial=0, columns=["c1"]
),
] == db["Gosh"].indexes
# Try a two-column unique index
create_index_unique_args = [
"create-index",
db_path,
"Gosh2",
"c1",
"c2",
"--unique",
]
result = CliRunner().invoke(cli.cli, create_index_unique_args)
assert result.exit_code == 0
assert [
Index(
seq=0,
name="idx_Gosh2_c1_c2",
unique=1,
origin="c",
partial=0,
columns=["c1", "c2"],
)
] == db["Gosh2"].indexes
# Trying to create the same index should fail
assert CliRunner().invoke(cli.cli, create_index_unique_args).exit_code != 0
# ... unless we use --if-not-exists or --ignore
for option in ("--if-not-exists", "--ignore"):
assert (
CliRunner().invoke(cli.cli, create_index_unique_args + [option]).exit_code
== 0
)
def test_create_index_analyze(db_path):
db = Database(db_path)
assert "sqlite_stat1" not in db.table_names()
assert [] == db["Gosh"].indexes
result = CliRunner().invoke(
cli.cli, ["create-index", db_path, "Gosh", "c1", "--analyze"]
)
assert result.exit_code == 0
assert "sqlite_stat1" in db.table_names()
def test_create_index_desc(db_path):
db = Database(db_path)
assert [] == db["Gosh"].indexes
result = CliRunner().invoke(cli.cli, ["create-index", db_path, "Gosh", "--", "-c1"])
assert result.exit_code == 0
assert (
db.execute("select sql from sqlite_master where type='index'").fetchone()[0]
== "CREATE INDEX [idx_Gosh_c1]\n ON [Gosh] ([c1] desc)"
)
@pytest.mark.parametrize(
"col_name,col_type,expected_schema",
(
("text", "TEXT", "CREATE TABLE [dogs] (\n [name] TEXT\n, [text] TEXT)"),
("text", "str", "CREATE TABLE [dogs] (\n [name] TEXT\n, [text] TEXT)"),
("text", "STR", "CREATE TABLE [dogs] (\n [name] TEXT\n, [text] TEXT)"),
(
"integer",
"INTEGER",
"CREATE TABLE [dogs] (\n [name] TEXT\n, [integer] INTEGER)",
),
(
"integer",
"int",
"CREATE TABLE [dogs] (\n [name] TEXT\n, [integer] INTEGER)",
),
("float", "FLOAT", "CREATE TABLE [dogs] (\n [name] TEXT\n, [float] FLOAT)"),
("blob", "blob", "CREATE TABLE [dogs] (\n [name] TEXT\n, [blob] BLOB)"),
("blob", "BLOB", "CREATE TABLE [dogs] (\n [name] TEXT\n, [blob] BLOB)"),
("blob", "bytes", "CREATE TABLE [dogs] (\n [name] TEXT\n, [blob] BLOB)"),
("blob", "BYTES", "CREATE TABLE [dogs] (\n [name] TEXT\n, [blob] BLOB)"),
("default", None, "CREATE TABLE [dogs] (\n [name] TEXT\n, [default] TEXT)"),
),
)
def test_add_column(db_path, col_name, col_type, expected_schema):
db = Database(db_path)
db.create_table("dogs", {"name": str})
assert db["dogs"].schema == "CREATE TABLE [dogs] (\n [name] TEXT\n)"
args = ["add-column", db_path, "dogs", col_name]
if col_type is not None:
args.append(col_type)
assert CliRunner().invoke(cli.cli, args).exit_code == 0
assert db["dogs"].schema == expected_schema
@pytest.mark.parametrize("ignore", (True, False))
def test_add_column_ignore(db_path, ignore):
db = Database(db_path)
db.create_table("dogs", {"name": str})
args = ["add-column", db_path, "dogs", "name"] + (["--ignore"] if ignore else [])
result = CliRunner().invoke(cli.cli, args)
if ignore:
assert result.exit_code == 0
else:
assert result.exit_code == 1
assert result.output == "Error: duplicate column name: name\n"
def test_add_column_not_null_default(db_path):
db = Database(db_path)
db.create_table("dogs", {"name": str})
assert db["dogs"].schema == "CREATE TABLE [dogs] (\n [name] TEXT\n)"
args = [
"add-column",
db_path,
"dogs",
"nickname",
"--not-null-default",
"dogs'dawg",
]
assert CliRunner().invoke(cli.cli, args).exit_code == 0
assert db["dogs"].schema == (
"CREATE TABLE [dogs] (\n"
" [name] TEXT\n"
", [nickname] TEXT NOT NULL DEFAULT 'dogs''dawg')"
)
@pytest.mark.parametrize(
"args,assert_message",
(
(
["books", "author_id", "authors", "id"],
"Explicit other_table and other_column",
),
(["books", "author_id", "authors"], "Explicit other_table, guess other_column"),
(["books", "author_id"], "Automatically guess other_table and other_column"),
),
)
def test_add_foreign_key(db_path, args, assert_message):
db = Database(db_path)
db["authors"].insert_all(
[{"id": 1, "name": "Sally"}, {"id": 2, "name": "Asheesh"}], pk="id"
)
db["books"].insert_all(
[
{"title": "Hedgehogs of the world", "author_id": 1},
{"title": "How to train your wolf", "author_id": 2},
]
)
assert (
CliRunner().invoke(cli.cli, ["add-foreign-key", db_path] + args).exit_code == 0
), assert_message
assert [
ForeignKey(
table="books", column="author_id", other_table="authors", other_column="id"
)
] == db["books"].foreign_keys
# Error if we try to add it twice:
result = CliRunner().invoke(
cli.cli, ["add-foreign-key", db_path, "books", "author_id", "authors", "id"]
)
assert result.exit_code != 0
assert (
"Error: Foreign key already exists for author_id => authors.id"
== result.output.strip()
)
# No error if we add it twice with --ignore
result = CliRunner().invoke(
cli.cli,
["add-foreign-key", db_path, "books", "author_id", "authors", "id", "--ignore"],
)
assert result.exit_code == 0
# Error if we try against an invalid column
result = CliRunner().invoke(
cli.cli, ["add-foreign-key", db_path, "books", "author_id", "authors", "bad"]
)
assert result.exit_code != 0
assert "Error: No such column: authors.bad" == result.output.strip()
def test_add_column_foreign_key(db_path):
db = Database(db_path)
db["authors"].insert({"id": 1, "name": "Sally"}, pk="id")
db["books"].insert({"title": "Hedgehogs of the world"})
# Add an author_id foreign key column to the books table
result = CliRunner().invoke(
cli.cli, ["add-column", db_path, "books", "author_id", "--fk", "authors"]
)
assert result.exit_code == 0, result.output
assert db["books"].schema == (
'CREATE TABLE "books" (\n'
" [title] TEXT,\n"
" [author_id] INTEGER REFERENCES [authors]([id])\n"
")"
)
# Try it again with a custom --fk-col
result = CliRunner().invoke(
cli.cli,
[
"add-column",
db_path,
"books",
"author_name_ref",
"--fk",
"authors",
"--fk-col",
"name",
],
)
assert result.exit_code == 0, result.output
assert db["books"].schema == (
'CREATE TABLE "books" (\n'
" [title] TEXT,\n"
" [author_id] INTEGER REFERENCES [authors]([id]),\n"
" [author_name_ref] TEXT REFERENCES [authors]([name])\n"
")"
)
# Throw an error if the --fk table does not exist
result = CliRunner().invoke(
cli.cli, ["add-column", db_path, "books", "author_id", "--fk", "bobcats"]
)
assert result.exit_code != 0
assert "table 'bobcats' does not exist" in str(result.exception)
def test_suggest_alter_if_column_missing(db_path):
db = Database(db_path)
db["authors"].insert({"id": 1, "name": "Sally"}, pk="id")
result = CliRunner().invoke(
cli.cli,
["insert", db_path, "authors", "-"],
input='{"id": 2, "name": "Barry", "age": 43}',
)
assert result.exit_code != 0
assert result.output.strip() == (
"Error: table authors has no column named age\n\n"
"Try using --alter to add additional columns"
)
def test_index_foreign_keys(db_path):
test_add_column_foreign_key(db_path)
db = Database(db_path)
assert [] == db["books"].indexes
result = CliRunner().invoke(cli.cli, ["index-foreign-keys", db_path])
assert result.exit_code == 0
assert [["author_id"], ["author_name_ref"]] == [
i.columns for i in db["books"].indexes
]
def test_enable_fts(db_path):
db = Database(db_path)
assert db["Gosh"].detect_fts() is None
result = CliRunner().invoke(
cli.cli, ["enable-fts", db_path, "Gosh", "c1", "--fts4"]
)
assert result.exit_code == 0
assert "Gosh_fts" == db["Gosh"].detect_fts()
# Table names with restricted chars are handled correctly.
# colons and dots are restricted characters for table names.
db["http://example.com"].create({"c1": str, "c2": str, "c3": str})
assert db["http://example.com"].detect_fts() is None
result = CliRunner().invoke(
cli.cli,
[
"enable-fts",
db_path,
"http://example.com",
"c1",
"--fts4",
"--tokenize",
"porter",
],
)
assert result.exit_code == 0
assert "http://example.com_fts" == db["http://example.com"].detect_fts()
# Check tokenize was set to porter
assert (
"CREATE VIRTUAL TABLE [http://example.com_fts] USING FTS4 (\n"
" [c1],\n"
" tokenize='porter',\n"
" content=[http://example.com]"
"\n)"
) == db["http://example.com_fts"].schema
db["http://example.com"].drop()
def test_enable_fts_replace(db_path):
db = Database(db_path)
assert db["Gosh"].detect_fts() is None
result = CliRunner().invoke(
cli.cli, ["enable-fts", db_path, "Gosh", "c1", "--fts4"]
)
assert result.exit_code == 0
assert "Gosh_fts" == db["Gosh"].detect_fts()
assert db["Gosh_fts"].columns_dict == {"c1": str}
# This should throw an error
result2 = CliRunner().invoke(
cli.cli, ["enable-fts", db_path, "Gosh", "c1", "--fts4"]
)
assert result2.exit_code == 1
assert result2.output == "Error: table [Gosh_fts] already exists\n"
# This should work
result3 = CliRunner().invoke(
cli.cli, ["enable-fts", db_path, "Gosh", "c2", "--fts4", "--replace"]
)
assert result3.exit_code == 0
assert db["Gosh_fts"].columns_dict == {"c2": str}
def test_enable_fts_with_triggers(db_path):
Database(db_path)["Gosh"].insert_all([{"c1": "baz"}])
exit_code = (
CliRunner()
.invoke(
cli.cli,
["enable-fts", db_path, "Gosh", "c1", "--fts4", "--create-triggers"],
)
.exit_code
)
assert exit_code == 0
def search(q):
return (
Database(db_path)
.execute("select c1 from Gosh_fts where c1 match ?", [q])
.fetchall()
)
assert [("baz",)] == search("baz")
Database(db_path)["Gosh"].insert_all([{"c1": "martha"}])
assert [("martha",)] == search("martha")
def test_populate_fts(db_path):
Database(db_path)["Gosh"].insert_all([{"c1": "baz"}])
exit_code = (
CliRunner()
.invoke(cli.cli, ["enable-fts", db_path, "Gosh", "c1", "--fts4"])
.exit_code
)
assert exit_code == 0
def search(q):
return (
Database(db_path)
.execute("select c1 from Gosh_fts where c1 match ?", [q])
.fetchall()
)
assert [("baz",)] == search("baz")
Database(db_path)["Gosh"].insert_all([{"c1": "martha"}])
assert [] == search("martha")
exit_code = (
CliRunner().invoke(cli.cli, ["populate-fts", db_path, "Gosh", "c1"]).exit_code
)
assert exit_code == 0
assert [("martha",)] == search("martha")
def test_disable_fts(db_path):
db = Database(db_path)
assert {"Gosh", "Gosh2"} == set(db.table_names())
db["Gosh"].enable_fts(["c1"], create_triggers=True)
assert {
"Gosh_fts",
"Gosh_fts_idx",
"Gosh_fts_data",
"Gosh2",
"Gosh_fts_config",
"Gosh",
"Gosh_fts_docsize",
} == set(db.table_names())
exit_code = CliRunner().invoke(cli.cli, ["disable-fts", db_path, "Gosh"]).exit_code
assert exit_code == 0
assert {"Gosh", "Gosh2"} == set(db.table_names())
def test_vacuum(db_path):
result = CliRunner().invoke(cli.cli, ["vacuum", db_path])
assert result.exit_code == 0
def test_dump(db_path):
result = CliRunner().invoke(cli.cli, ["dump", db_path])
assert result.exit_code == 0
assert result.output.startswith("BEGIN TRANSACTION;")
assert result.output.strip().endswith("COMMIT;")
@pytest.mark.parametrize("tables", ([], ["Gosh"], ["Gosh2"]))
def test_optimize(db_path, tables):
db = Database(db_path)
with db.conn:
for table in ("Gosh", "Gosh2"):
db[table].insert_all(
[
{
"c1": "verb{}".format(i),
"c2": "noun{}".format(i),
"c3": "adjective{}".format(i),
}
for i in range(10000)
]
)
db["Gosh"].enable_fts(["c1", "c2", "c3"], fts_version="FTS4")
db["Gosh2"].enable_fts(["c1", "c2", "c3"], fts_version="FTS5")
size_before_optimize = os.stat(db_path).st_size
result = CliRunner().invoke(cli.cli, ["optimize", db_path] + tables)
assert result.exit_code == 0
size_after_optimize = os.stat(db_path).st_size
# Weirdest thing: tests started failing because size after
# ended up larger than size before in some cases. I think
# it's OK to tolerate that happening, though it's very strange.
assert size_after_optimize <= (size_before_optimize + 10000)
# Soundness check that --no-vacuum doesn't throw errors:
result = CliRunner().invoke(cli.cli, ["optimize", "--no-vacuum", db_path])
assert result.exit_code == 0
def test_rebuild_fts_fixes_docsize_error(db_path):
db = Database(db_path, recursive_triggers=False)
records = [
{
"c1": "verb{}".format(i),
"c2": "noun{}".format(i),
"c3": "adjective{}".format(i),
}
for i in range(10000)
]
with db.conn:
db["fts5_table"].insert_all(records, pk="c1")
db["fts5_table"].enable_fts(
["c1", "c2", "c3"], fts_version="FTS5", create_triggers=True
)
# Search should work
assert list(db["fts5_table"].search("verb1"))
# Replicate docsize error from this issue for FTS5
# https://github.com/simonw/sqlite-utils/issues/149
assert db["fts5_table_fts_docsize"].count == 10000
db["fts5_table"].insert_all(records, replace=True)
assert db["fts5_table"].count == 10000
assert db["fts5_table_fts_docsize"].count == 20000
# Running rebuild-fts should fix this
result = CliRunner().invoke(cli.cli, ["rebuild-fts", db_path, "fts5_table"])
assert result.exit_code == 0
assert db["fts5_table_fts_docsize"].count == 10000
@pytest.mark.parametrize(
"format,expected",
[
("--csv", "id,name,age\n1,Cleo,4\n2,Pancakes,2\n"),
("--tsv", "id\tname\tage\n1\tCleo\t4\n2\tPancakes\t2\n"),
],
)
def test_query_csv(db_path, format, expected):
db = Database(db_path)
with db.conn:
db["dogs"].insert_all(
[
{"id": 1, "age": 4, "name": "Cleo"},
{"id": 2, "age": 2, "name": "Pancakes"},
]
)
result = CliRunner().invoke(
cli.cli, [db_path, "select id, name, age from dogs", format]
)
assert result.exit_code == 0
assert result.output.replace("\r", "") == expected
# Test the no-headers option:
result = CliRunner().invoke(
cli.cli, [db_path, "select id, name, age from dogs", "--no-headers", format]
)
expected_rest = "\n".join(expected.split("\n")[1:]).strip()
assert result.output.strip().replace("\r", "") == expected_rest
_all_query = "select id, name, age from dogs"
_one_query = "select id, name, age from dogs where id = 1"
@pytest.mark.parametrize(
"sql,args,expected",
[
(
_all_query,
[],
'[{"id": 1, "name": "Cleo", "age": 4},\n {"id": 2, "name": "Pancakes", "age": 2}]',
),
(
_all_query,
["--nl"],
'{"id": 1, "name": "Cleo", "age": 4}\n{"id": 2, "name": "Pancakes", "age": 2}',
),
(_all_query, ["--arrays"], '[[1, "Cleo", 4],\n [2, "Pancakes", 2]]'),
(_all_query, ["--arrays", "--nl"], '[1, "Cleo", 4]\n[2, "Pancakes", 2]'),
(_one_query, [], '[{"id": 1, "name": "Cleo", "age": 4}]'),
(_one_query, ["--nl"], '{"id": 1, "name": "Cleo", "age": 4}'),
(_one_query, ["--arrays"], '[[1, "Cleo", 4]]'),
(_one_query, ["--arrays", "--nl"], '[1, "Cleo", 4]'),
(
"select id, dog(age) from dogs",
["--functions", "def dog(i):\n return i * 7"],
'[{"id": 1, "dog(age)": 28},\n {"id": 2, "dog(age)": 14}]',
),
],
)
def test_query_json(db_path, sql, args, expected):
db = Database(db_path)
with db.conn:
db["dogs"].insert_all(
[
{"id": 1, "age": 4, "name": "Cleo"},
{"id": 2, "age": 2, "name": "Pancakes"},
]
)
result = CliRunner().invoke(cli.cli, [db_path, sql] + args)
assert expected == result.output.strip()
def test_query_json_empty(db_path):
result = CliRunner().invoke(
cli.cli,
[db_path, "select * from sqlite_master where 0"],
)
assert result.output.strip() == "[]"
def test_query_invalid_function(db_path):
result = CliRunner().invoke(
cli.cli, [db_path, "select bad()", "--functions", "def invalid_python"]
)
assert result.exit_code == 1
assert result.output.startswith("Error: Error in functions definition:")
TEST_FUNCTIONS = """
def zero():
return 0
def one(a):
return a
def _two(a, b):
return a + b
def two(a, b):
return _two(a, b)
"""
def test_query_complex_function(db_path):
result = CliRunner().invoke(
cli.cli,
[
db_path,
"select zero(), one(1), two(1, 2)",
"--functions",
TEST_FUNCTIONS,
],
)
assert result.exit_code == 0
assert json.loads(result.output.strip()) == [
{"zero()": 0, "one(1)": 1, "two(1, 2)": 3}
]
@pytest.mark.skipif(
not _supports_pragma_function_list(),
reason="Needs SQLite version that supports pragma_function_list()",
)
def test_hidden_functions_are_hidden(db_path):
result = CliRunner().invoke(
cli.cli,
[
db_path,
"select name from pragma_function_list()",
"--functions",
TEST_FUNCTIONS,
],
)
assert result.exit_code == 0
functions = {r["name"] for r in json.loads(result.output.strip())}
assert "zero" in functions
assert "one" in functions
assert "two" in functions
assert "_two" not in functions
LOREM_IPSUM_COMPRESSED = (
b"x\x9c\xed\xd1\xcdq\x03!\x0c\x05\xe0\xbb\xabP\x01\x1eW\x91\xdc|M\x01\n\xc8\x8e"
b"f\xf83H\x1e\x97\x1f\x91M\x8e\xe9\xe0\xdd\x96\x05\x84\xf4\xbek\x9fRI\xc7\xf2J"
b"\xb9\x97>i\xa9\x11W\xb13\xa5\xde\x96$\x13\xf3I\x9cu\xe8J\xda\xee$EcsI\x8e\x0b"
b"$\xea\xab\xf6L&u\xc4emI\xb3foFnT\xf83\xca\x93\xd8QZ\xa8\xf2\xbd1q\xd1\x87\xf3"
b"\x85>\x8c\xa4i\x8d\xdaTu\x7f<c\xc9\xf5L\x0f\xd7E\xad/\x9b\x9eI^2\x93\x1a\x9b"
b"\xf6F^\n\xd7\xd4\x8f\xca\xfb\x90.\xdd/\xfd\x94\xd4\x11\x87I8\x1a\xaf\xd1S?\x06"
b"\x88\xa7\xecBo\xbb$\xbb\t\xe9\xf4\xe8\xe4\x98U\x1bM\x19S\xbe\xa4e\x991x\xfc"
b"x\xf6\xe2#\x9e\x93h'&%YK(i)\x7f\t\xc5@N7\xbf+\x1b\xb5\xdd\x10\r\x9e\xb1\xf0"
b"y\xa1\xf7W\x92a\xe2;\xc6\xc8\xa0\xa7\xc4\x92\xe2\\\xf2\xa1\x99m\xdf\x88)\xc6"
b"\xec\x9a\xa5\xed\x14wR\xf1h\xf22x\xcfM\xfdv\xd3\xa4LY\x96\xcc\xbd[{\xd9m\xf0"
b"\x0eH#\x8e\xf5\x9b\xab\xd7\xcb\xe9t\x05\x1f\xf8\xc0\x07>\xf0\x81\x0f|\xe0\x03"
b"\x1f\xf8\xc0\x07>\xf0\x81\x0f|\xe0\x03\x1f\xf8\xc0\x07>\xf0\x81\x0f|\xe0\x03"
b"\x1f\xf8\xc0\x07>\xf0\x81\x0f|\xe0\x03\x1f\xf8\xc0\x07>\xf0\x81\x0f|\xe0\x03"
b"\x1f\xf8\xc0\x07>\xf0\x81\x0f|\xe0\x03\x1f\xf8\xc0\x07>\xf0\x81\x0f|\xe0\x03"
b"\x1f\xf8\xc0\x07>\xf0\x81\x0f|\xe0\xfb\x8f\xef\x1b\x9b\x06\x83}"
)
def test_query_json_binary(db_path):
db = Database(db_path)
with db.conn:
db["files"].insert(
{
"name": "lorem.txt",
"sz": 16984,
"data": LOREM_IPSUM_COMPRESSED,
},
pk="name",
)
result = CliRunner().invoke(cli.cli, [db_path, "select name, sz, data from files"])
assert result.exit_code == 0, str(result)
assert json.loads(result.output.strip()) == [
{
"name": "lorem.txt",
"sz": 16984,
"data": {
"$base64": True,
"encoded": (
(
"eJzt0c1xAyEMBeC7q1ABHleR3HxNAQrIjmb4M0gelx+RTY7p4N2WBYT0vmufUknH"
"8kq5lz5pqRFXsTOl3pYkE/NJnHXoStruJEVjc0mOCyTqq/ZMJnXEZW1Js2ZvRm5U+"
"DPKk9hRWqjyvTFx0YfzhT6MpGmN2lR1fzxjyfVMD9dFrS+bnkleMpMam/ZGXgrX1I"
"/K+5Au3S/9lNQRh0k4Gq/RUz8GiKfsQm+7JLsJ6fTo5JhVG00ZU76kZZkxePx49uI"
"jnpNoJyYlWUsoaSl/CcVATje/Kxu13RANnrHweaH3V5Jh4jvGyKCnxJLiXPKhmW3f"
"iCnG7Jql7RR3UvFo8jJ4z039dtOkTFmWzL1be9lt8A5II471m6vXy+l0BR/4wAc+8"
"IEPfOADH/jABz7wgQ984AMf+MAHPvCBD3zgAx/4wAc+8IEPfOADH/jABz7wgQ984A"
"Mf+MAHPvCBD3zgAx/4wAc+8IEPfOADH/jABz7wgQ984PuP7xubBoN9"
)
),
},
}
]
@pytest.mark.parametrize(
"sql,params,expected",
[
("select 1 + 1 as out", {"p": "2"}, 2),
("select 1 + :p as out", {"p": "2"}, 3),
(
"select :hello as out",
{"hello": """This"has'many'quote"s"""},
"""This"has'many'quote"s""",
),
],
)
def test_query_params(db_path, sql, params, expected):
extra_args = []
for key, value in params.items():
extra_args.extend(["-p", key, value])
result = CliRunner().invoke(cli.cli, [db_path, sql] + extra_args)
assert result.exit_code == 0, str(result)
assert json.loads(result.output.strip()) == [{"out": expected}]
def test_query_json_with_json_cols(db_path):
db = Database(db_path)
with db.conn:
db["dogs"].insert(
{
"id": 1,
"name": "Cleo",
"friends": [{"name": "Pancakes"}, {"name": "Bailey"}],
}
)
result = CliRunner().invoke(
cli.cli, [db_path, "select id, name, friends from dogs"]
)
assert (
r"""
[{"id": 1, "name": "Cleo", "friends": "[{\"name\": \"Pancakes\"}, {\"name\": \"Bailey\"}]"}]
""".strip()
== result.output.strip()
)
# With --json-cols:
result = CliRunner().invoke(
cli.cli, [db_path, "select id, name, friends from dogs", "--json-cols"]
)
expected = r"""
[{"id": 1, "name": "Cleo", "friends": [{"name": "Pancakes"}, {"name": "Bailey"}]}]
""".strip()
assert expected == result.output.strip()
# Test rows command too
result_rows = CliRunner().invoke(cli.cli, ["rows", db_path, "dogs", "--json-cols"])
assert expected == result_rows.output.strip()
@pytest.mark.parametrize(
"content,is_binary",
[(b"\x00\x0fbinary", True), ("this is text", False), (1, False), (1.5, False)],
)
def test_query_raw(db_path, content, is_binary):
Database(db_path)["files"].insert({"content": content})
result = CliRunner().invoke(
cli.cli, [db_path, "select content from files", "--raw"]
)
if is_binary:
assert result.stdout_bytes == content
else:
assert result.output == str(content)
@pytest.mark.parametrize(
"content,is_binary",
[(b"\x00\x0fbinary", True), ("this is text", False), (1, False), (1.5, False)],
)
def test_query_raw_lines(db_path, content, is_binary):
Database(db_path)["files"].insert_all({"content": content} for _ in range(3))
result = CliRunner().invoke(
cli.cli, [db_path, "select content from files", "--raw-lines"]
)
if is_binary:
assert result.stdout_bytes == b"\n".join(content for _ in range(3)) + b"\n"
else:
assert result.output == "\n".join(str(content) for _ in range(3)) + "\n"
def test_query_memory_does_not_create_file(tmpdir):
owd = os.getcwd()
try:
os.chdir(tmpdir)
# This should create a foo.db file
CliRunner().invoke(cli.cli, ["foo.db", "select sqlite_version()"])
# This should NOT create a file
result = CliRunner().invoke(cli.cli, [":memory:", "select sqlite_version()"])
assert ["sqlite_version()"] == list(json.loads(result.output)[0].keys())
finally:
os.chdir(owd)
assert ["foo.db"] == os.listdir(tmpdir)
@pytest.mark.parametrize(
"args,expected",
[
(
[],
'[{"id": 1, "name": "Cleo", "age": 4},\n {"id": 2, "name": "Pancakes", "age": 2}]',
),
(
["--nl"],
'{"id": 1, "name": "Cleo", "age": 4}\n{"id": 2, "name": "Pancakes", "age": 2}',
),
(["--arrays"], '[[1, "Cleo", 4],\n [2, "Pancakes", 2]]'),
(["--arrays", "--nl"], '[1, "Cleo", 4]\n[2, "Pancakes", 2]'),
(
["--nl", "-c", "age", "-c", "name"],
'{"age": 4, "name": "Cleo"}\n{"age": 2, "name": "Pancakes"}',
),
# --limit and --offset
(
["-c", "name", "--limit", "1"],
'[{"name": "Cleo"}]',
),
(
["-c", "name", "--limit", "1", "--offset", "1"],
'[{"name": "Pancakes"}]',
),
# --where
(
["-c", "name", "--where", "id = 1"],
'[{"name": "Cleo"}]',
),
(
["-c", "name", "--where", "id = :id", "-p", "id", "1"],
'[{"name": "Cleo"}]',
),
(
["-c", "name", "--where", "id = :id", "--param", "id", "1"],
'[{"name": "Cleo"}]',
),
# --order
(
["-c", "id", "--order", "id desc", "--limit", "1"],
'[{"id": 2}]',
),
(
["-c", "id", "--order", "id", "--limit", "1"],
'[{"id": 1}]',
),
],
)
def test_rows(db_path, args, expected):
db = Database(db_path)
with db.conn:
db["dogs"].insert_all(
[
{"id": 1, "age": 4, "name": "Cleo"},
{"id": 2, "age": 2, "name": "Pancakes"},
],
column_order=("id", "name", "age"),
)
result = CliRunner().invoke(cli.cli, ["rows", db_path, "dogs"] + args)
assert expected == result.output.strip()
def test_upsert(db_path, tmpdir):
json_path = str(tmpdir / "dogs.json")
db = Database(db_path)
insert_dogs = [
{"id": 1, "name": "Cleo", "age": 4},
{"id": 2, "name": "Nixie", "age": 4},
]
write_json(json_path, insert_dogs)
result = CliRunner().invoke(
cli.cli,
["insert", db_path, "dogs", json_path, "--pk", "id"],
catch_exceptions=False,
)
assert result.exit_code == 0, result.output
assert 2 == db["dogs"].count
# Now run the upsert to update just their ages
upsert_dogs = [
{"id": 1, "age": 5},
{"id": 2, "age": 5},
]
write_json(json_path, upsert_dogs)
result = CliRunner().invoke(
cli.cli,
["upsert", db_path, "dogs", json_path, "--pk", "id"],
catch_exceptions=False,
)
assert result.exit_code == 0, result.output
assert list(db.query("select * from dogs order by id")) == [
{"id": 1, "name": "Cleo", "age": 5},
{"id": 2, "name": "Nixie", "age": 5},
]
def test_upsert_pk_required(db_path, tmpdir):
json_path = str(tmpdir / "dogs.json")
insert_dogs = [
{"id": 1, "name": "Cleo", "age": 4},
{"id": 2, "name": "Nixie", "age": 4},
]
write_json(json_path, insert_dogs)
result = CliRunner().invoke(
cli.cli,
["upsert", db_path, "dogs", json_path],
catch_exceptions=False,
)
assert result.exit_code == 2
assert "Error: Missing option '--pk'" in result.output
def test_upsert_analyze(db_path, tmpdir):
db = Database(db_path)
db["rows"].insert({"id": 1, "foo": "x", "n": 3}, pk="id")
db["rows"].create_index(["n"])
assert "sqlite_stat1" not in db.table_names()
result = CliRunner().invoke(
cli.cli,
["upsert", db_path, "rows", "-", "--nl", "--analyze", "--pk", "id"],
input='{"id": 2, "foo": "bar", "n": 1}',
)
assert result.exit_code == 0, result.output
assert "sqlite_stat1" in db.table_names()
def test_upsert_flatten(tmpdir):
db_path = str(tmpdir / "flat.db")
db = Database(db_path)
db["upsert_me"].insert({"id": 1, "name": "Example"}, pk="id")
result = CliRunner().invoke(
cli.cli,
["upsert", db_path, "upsert_me", "-", "--flatten", "--pk", "id", "--alter"],
input=json.dumps({"id": 1, "nested": {"two": 2}}),
)
assert result.exit_code == 0
assert list(db.query("select * from upsert_me")) == [
{"id": 1, "name": "Example", "nested_two": 2}
]
def test_upsert_alter(db_path, tmpdir):
json_path = str(tmpdir / "dogs.json")
db = Database(db_path)
insert_dogs = [{"id": 1, "name": "Cleo"}]
write_json(json_path, insert_dogs)
result = CliRunner().invoke(
cli.cli, ["insert", db_path, "dogs", json_path, "--pk", "id"]
)
assert result.exit_code == 0, result.output
# Should fail with error code if no --alter
upsert_dogs = [{"id": 1, "age": 5}]
write_json(json_path, upsert_dogs)
result = CliRunner().invoke(
cli.cli, ["upsert", db_path, "dogs", json_path, "--pk", "id"]
)
assert result.exit_code == 1
assert (
"Error: table dogs has no column named age\n\n"
"Try using --alter to add additional columns"
) == result.output.strip()
# Should succeed with --alter
result = CliRunner().invoke(
cli.cli, ["upsert", db_path, "dogs", json_path, "--pk", "id", "--alter"]
)
assert result.exit_code == 0
assert list(db.query("select * from dogs order by id")) == [
{"id": 1, "name": "Cleo", "age": 5},
]
@pytest.mark.parametrize(
"args,schema",
[
# No primary key
(
[
"name",
"text",
"age",
"integer",
],
("CREATE TABLE [t] (\n [name] TEXT,\n [age] INTEGER\n)"),
),
# All types:
(
[
"id",
"integer",
"name",
"text",
"age",
"integer",
"weight",
"float",
"thumbnail",
"blob",
"--pk",
"id",
],
(
"CREATE TABLE [t] (\n"
" [id] INTEGER PRIMARY KEY,\n"
" [name] TEXT,\n"
" [age] INTEGER,\n"
" [weight] FLOAT,\n"
" [thumbnail] BLOB\n"
")"
),
),
# Not null:
(
["name", "text", "--not-null", "name"],
("CREATE TABLE [t] (\n" " [name] TEXT NOT NULL\n" ")"),
),
# Default:
(
["age", "integer", "--default", "age", "3"],
("CREATE TABLE [t] (\n" " [age] INTEGER DEFAULT '3'\n" ")"),
),
# Compound primary key
(
["category", "text", "name", "text", "--pk", "category", "--pk", "name"],
(
"CREATE TABLE [t] (\n [category] TEXT,\n [name] TEXT,\n"
" PRIMARY KEY ([category], [name])\n)"
),
),
],
)
def test_create_table(args, schema):
runner = CliRunner()
with runner.isolated_filesystem():
result = runner.invoke(
cli.cli,
[
"create-table",
"test.db",
"t",
]
+ args,
catch_exceptions=False,
)
assert result.exit_code == 0
db = Database("test.db")
assert schema == db["t"].schema
def test_create_table_foreign_key():
runner = CliRunner()
creates = (
["authors", "id", "integer", "name", "text", "--pk", "id"],
[
"books",
"id",
"integer",
"title",
"text",
"author_id",
"integer",
"--pk",
"id",
"--fk",
"author_id",
"authors",
"id",
],
)
with runner.isolated_filesystem():
for args in creates:
result = runner.invoke(
cli.cli, ["create-table", "books.db"] + args, catch_exceptions=False
)
assert result.exit_code == 0
db = Database("books.db")
assert (
"CREATE TABLE [authors] (\n"
" [id] INTEGER PRIMARY KEY,\n"
" [name] TEXT\n"
")"
) == db["authors"].schema
assert (
"CREATE TABLE [books] (\n"
" [id] INTEGER PRIMARY KEY,\n"
" [title] TEXT,\n"
" [author_id] INTEGER REFERENCES [authors]([id])\n"
")"
) == db["books"].schema
def test_create_table_error_if_table_exists():
runner = CliRunner()
with runner.isolated_filesystem():
db = Database("test.db")
db["dogs"].insert({"name": "Cleo"})
result = runner.invoke(
cli.cli, ["create-table", "test.db", "dogs", "id", "integer"]
)
assert result.exit_code == 1
assert (
'Error: Table "dogs" already exists. Use --replace to delete and replace it.'
== result.output.strip()
)
def test_create_table_ignore():
runner = CliRunner()
with runner.isolated_filesystem():
db = Database("test.db")
db["dogs"].insert({"name": "Cleo"})
result = runner.invoke(
cli.cli, ["create-table", "test.db", "dogs", "id", "integer", "--ignore"]
)
assert result.exit_code == 0
assert "CREATE TABLE [dogs] (\n [name] TEXT\n)" == db["dogs"].schema
def test_create_table_replace():
runner = CliRunner()
with runner.isolated_filesystem():
db = Database("test.db")
db["dogs"].insert({"name": "Cleo"})
result = runner.invoke(
cli.cli, ["create-table", "test.db", "dogs", "id", "integer", "--replace"]
)
assert result.exit_code == 0
assert "CREATE TABLE [dogs] (\n [id] INTEGER\n)" == db["dogs"].schema
def test_create_view():
runner = CliRunner()
with runner.isolated_filesystem():
db = Database("test.db")
result = runner.invoke(
cli.cli, ["create-view", "test.db", "version", "select sqlite_version()"]
)
assert result.exit_code == 0
assert "CREATE VIEW version AS select sqlite_version()" == db["version"].schema
def test_create_view_error_if_view_exists():
runner = CliRunner()
with runner.isolated_filesystem():
db = Database("test.db")
db.create_view("version", "select sqlite_version() + 1")
result = runner.invoke(
cli.cli, ["create-view", "test.db", "version", "select sqlite_version()"]
)
assert result.exit_code == 1
assert (
'Error: View "version" already exists. Use --replace to delete and replace it.'
== result.output.strip()
)
def test_create_view_ignore():
runner = CliRunner()
with runner.isolated_filesystem():
db = Database("test.db")
db.create_view("version", "select sqlite_version() + 1")
result = runner.invoke(
cli.cli,
[
"create-view",
"test.db",
"version",
"select sqlite_version()",
"--ignore",
],
)
assert result.exit_code == 0
assert (
"CREATE VIEW version AS select sqlite_version() + 1" == db["version"].schema
)
def test_create_view_replace():
runner = CliRunner()
with runner.isolated_filesystem():
db = Database("test.db")
db.create_view("version", "select sqlite_version() + 1")
result = runner.invoke(
cli.cli,
[
"create-view",
"test.db",
"version",
"select sqlite_version()",
"--replace",
],
)
assert result.exit_code == 0
assert "CREATE VIEW version AS select sqlite_version()" == db["version"].schema
def test_drop_table():
runner = CliRunner()
with runner.isolated_filesystem():
db = Database("test.db")
db["t"].create({"pk": int}, pk="pk")
assert "t" in db.table_names()
result = runner.invoke(
cli.cli,
[
"drop-table",
"test.db",
"t",
],
)
assert result.exit_code == 0
assert "t" not in db.table_names()
def test_drop_table_error():
runner = CliRunner()
with runner.isolated_filesystem():
db = Database("test.db")
db["t"].create({"pk": int}, pk="pk")
result = runner.invoke(
cli.cli,
[
"drop-table",
"test.db",
"t2",
],
)
assert result.exit_code == 1
assert 'Error: Table "t2" does not exist' == result.output.strip()
# Using --ignore suppresses that error
result = runner.invoke(
cli.cli,
["drop-table", "test.db", "t2", "--ignore"],
)
assert result.exit_code == 0
def test_drop_view():
runner = CliRunner()
with runner.isolated_filesystem():
db = Database("test.db")
db.create_view("hello", "select 1")
assert "hello" in db.view_names()
result = runner.invoke(
cli.cli,
[
"drop-view",
"test.db",
"hello",
],
)
assert result.exit_code == 0
assert "hello" not in db.view_names()
def test_drop_view_error():
runner = CliRunner()
with runner.isolated_filesystem():
db = Database("test.db")
db["t"].create({"pk": int}, pk="pk")
result = runner.invoke(
cli.cli,
[
"drop-view",
"test.db",
"t2",
],
)
assert result.exit_code == 1
assert 'Error: View "t2" does not exist' == result.output.strip()
# Using --ignore suppresses that error
result = runner.invoke(
cli.cli,
["drop-view", "test.db", "t2", "--ignore"],
)
assert result.exit_code == 0
def test_enable_wal():
runner = CliRunner()
dbs = ["test.db", "test2.db"]
with runner.isolated_filesystem():
for dbname in dbs:
db = Database(dbname)
db["t"].create({"pk": int}, pk="pk")
assert db.journal_mode == "delete"
result = runner.invoke(cli.cli, ["enable-wal"] + dbs, catch_exceptions=False)
assert result.exit_code == 0
for dbname in dbs:
db = Database(dbname)
assert db.journal_mode == "wal"
def test_disable_wal():
runner = CliRunner()
dbs = ["test.db", "test2.db"]
with runner.isolated_filesystem():
for dbname in dbs:
db = Database(dbname)
db["t"].create({"pk": int}, pk="pk")
db.enable_wal()
assert db.journal_mode == "wal"
result = runner.invoke(cli.cli, ["disable-wal"] + dbs)
assert result.exit_code == 0
for dbname in dbs:
db = Database(dbname)
assert db.journal_mode == "delete"
@pytest.mark.parametrize(
"args,expected",
[
(
[],
'[{"rows_affected": 1}]',
),
(["-t"], "rows_affected\n---------------\n 1"),
],
)
def test_query_update(db_path, args, expected):
db = Database(db_path)
with db.conn:
db["dogs"].insert_all(
[
{"id": 1, "age": 4, "name": "Cleo"},
]
)
result = CliRunner().invoke(
cli.cli, [db_path, "update dogs set age = 5 where name = 'Cleo'"] + args
)
assert expected == result.output.strip()
assert list(db.query("select * from dogs")) == [
{"id": 1, "age": 5, "name": "Cleo"},
]
def test_add_foreign_keys(db_path):
db = Database(db_path)
db["countries"].insert({"id": 7, "name": "Panama"}, pk="id")
db["authors"].insert({"id": 3, "name": "Matilda", "country_id": 7}, pk="id")
db["books"].insert({"id": 2, "title": "Wolf anatomy", "author_id": 3}, pk="id")
assert db["authors"].foreign_keys == []
assert db["books"].foreign_keys == []
result = CliRunner().invoke(
cli.cli,
[
"add-foreign-keys",
db_path,
"authors",
"country_id",
"countries",
"id",
"books",
"author_id",
"authors",
"id",
],
)
assert result.exit_code == 0
assert db["authors"].foreign_keys == [
ForeignKey(
table="authors",
column="country_id",
other_table="countries",
other_column="id",
)
]
assert db["books"].foreign_keys == [
ForeignKey(
table="books", column="author_id", other_table="authors", other_column="id"
)
]
@pytest.mark.parametrize(
"args,expected_schema",
[
(
[],
(
'CREATE TABLE "dogs" (\n'
" [id] INTEGER PRIMARY KEY,\n"
" [age] INTEGER NOT NULL DEFAULT '1',\n"
" [name] TEXT\n"
")"
),
),
(
["--type", "age", "text"],
(
'CREATE TABLE "dogs" (\n'
" [id] INTEGER PRIMARY KEY,\n"
" [age] TEXT NOT NULL DEFAULT '1',\n"
" [name] TEXT\n"
")"
),
),
(
["--drop", "age"],
(
'CREATE TABLE "dogs" (\n'
" [id] INTEGER PRIMARY KEY,\n"
" [name] TEXT\n"
")"
),
),
(
["--rename", "age", "age2", "--rename", "id", "pk"],
(
'CREATE TABLE "dogs" (\n'
" [pk] INTEGER PRIMARY KEY,\n"
" [age2] INTEGER NOT NULL DEFAULT '1',\n"
" [name] TEXT\n"
")"
),
),
(
["--not-null", "name"],
(
'CREATE TABLE "dogs" (\n'
" [id] INTEGER PRIMARY KEY,\n"
" [age] INTEGER NOT NULL DEFAULT '1',\n"
" [name] TEXT NOT NULL\n"
")"
),
),
(
["--not-null-false", "age"],
(
'CREATE TABLE "dogs" (\n'
" [id] INTEGER PRIMARY KEY,\n"
" [age] INTEGER DEFAULT '1',\n"
" [name] TEXT\n"
")"
),
),
(
["--pk", "name"],
(
'CREATE TABLE "dogs" (\n'
" [id] INTEGER,\n"
" [age] INTEGER NOT NULL DEFAULT '1',\n"
" [name] TEXT PRIMARY KEY\n"
")"
),
),
(
["--pk-none"],
(
'CREATE TABLE "dogs" (\n'
" [id] INTEGER,\n"
" [age] INTEGER NOT NULL DEFAULT '1',\n"
" [name] TEXT\n"
")"
),
),
(
["--default", "name", "Turnip"],
(
'CREATE TABLE "dogs" (\n'
" [id] INTEGER PRIMARY KEY,\n"
" [age] INTEGER NOT NULL DEFAULT '1',\n"
" [name] TEXT DEFAULT 'Turnip'\n"
")"
),
),
(
["--default-none", "age"],
(
'CREATE TABLE "dogs" (\n'
" [id] INTEGER PRIMARY KEY,\n"
" [age] INTEGER NOT NULL,\n"
" [name] TEXT\n"
")"
),
),
(
["-o", "name", "--column-order", "age", "-o", "id"],
(
'CREATE TABLE "dogs" (\n'
" [name] TEXT,\n"
" [age] INTEGER NOT NULL DEFAULT '1',\n"
" [id] INTEGER PRIMARY KEY\n"
")"
),
),
],
)
def test_transform(db_path, args, expected_schema):
db = Database(db_path)
with db.conn:
db["dogs"].insert(
{"id": 1, "age": 4, "name": "Cleo"},
not_null={"age"},
defaults={"age": 1},
pk="id",
)
result = CliRunner().invoke(cli.cli, ["transform", db_path, "dogs"] + args)
print(result.output)
assert result.exit_code == 0
schema = db["dogs"].schema
assert schema == expected_schema
@pytest.mark.parametrize(
"extra_args,expected_schema",
(
(
["--drop-foreign-key", "country"],
(
'CREATE TABLE "places" (\n'
" [id] INTEGER PRIMARY KEY,\n"
" [name] TEXT,\n"
" [country] INTEGER,\n"
" [city] INTEGER REFERENCES [city]([id]),\n"
" [continent] INTEGER\n"
")"
),
),
(
["--drop-foreign-key", "country", "--drop-foreign-key", "city"],
(
'CREATE TABLE "places" (\n'
" [id] INTEGER PRIMARY KEY,\n"
" [name] TEXT,\n"
" [country] INTEGER,\n"
" [city] INTEGER,\n"
" [continent] INTEGER\n"
")"
),
),
(
["--add-foreign-key", "continent", "continent", "id"],
(
'CREATE TABLE "places" (\n'
" [id] INTEGER PRIMARY KEY,\n"
" [name] TEXT,\n"
" [country] INTEGER REFERENCES [country]([id]),\n"
" [city] INTEGER REFERENCES [city]([id]),\n"
" [continent] INTEGER REFERENCES [continent]([id])\n"
")"
),
),
),
)
def test_transform_add_or_drop_foreign_key(db_path, extra_args, expected_schema):
db = Database(db_path)
with db.conn:
# Create table with three foreign keys so we can drop two of them
db["continent"].insert({"id": 1, "name": "Europe"}, pk="id")
db["country"].insert({"id": 1, "name": "France"}, pk="id")
db["city"].insert({"id": 24, "name": "Paris"}, pk="id")
db["places"].insert(
{
"id": 32,
"name": "Caveau de la Huchette",
"country": 1,
"city": 24,
"continent": 1,
},
foreign_keys=("country", "city"),
pk="id",
)
result = CliRunner().invoke(
cli.cli,
[
"transform",
db_path,
"places",
]
+ extra_args,
)
assert result.exit_code == 0
schema = db["places"].schema
assert schema == expected_schema
_common_other_schema = (
"CREATE TABLE [species] (\n [id] INTEGER PRIMARY KEY,\n [species] TEXT\n)"
)
@pytest.mark.parametrize(
"args,expected_table_schema,expected_other_schema",
[
(
[],
(
'CREATE TABLE "trees" (\n'
" [id] INTEGER PRIMARY KEY,\n"
" [address] TEXT,\n"
" [species_id] INTEGER REFERENCES [species]([id])\n"
")"
),
_common_other_schema,
),
(
["--table", "custom_table"],
(
'CREATE TABLE "trees" (\n'
" [id] INTEGER PRIMARY KEY,\n"
" [address] TEXT,\n"
" [custom_table_id] INTEGER REFERENCES [custom_table]([id])\n"
")"
),
"CREATE TABLE [custom_table] (\n [id] INTEGER PRIMARY KEY,\n [species] TEXT\n)",
),
(
["--fk-column", "custom_fk"],
(
'CREATE TABLE "trees" (\n'
" [id] INTEGER PRIMARY KEY,\n"
" [address] TEXT,\n"
" [custom_fk] INTEGER REFERENCES [species]([id])\n"
")"
),
_common_other_schema,
),
(
["--rename", "name", "name2"],
'CREATE TABLE "trees" (\n'
" [id] INTEGER PRIMARY KEY,\n"
" [address] TEXT,\n"
" [species_id] INTEGER REFERENCES [species]([id])\n"
")",
"CREATE TABLE [species] (\n [id] INTEGER PRIMARY KEY,\n [species] TEXT\n)",
),
],
)
def test_extract(db_path, args, expected_table_schema, expected_other_schema):
db = Database(db_path)
with db.conn:
db["trees"].insert(
{"id": 1, "address": "4 Park Ave", "species": "Palm"},
pk="id",
)
result = CliRunner().invoke(
cli.cli, ["extract", db_path, "trees", "species"] + args
)
print(result.output)
assert result.exit_code == 0
schema = db["trees"].schema
assert schema == expected_table_schema
other_schema = [t for t in db.tables if t.name not in ("trees", "Gosh", "Gosh2")][
0
].schema
assert other_schema == expected_other_schema
def test_insert_encoding(tmpdir):
db_path = str(tmpdir / "test.db")
latin1_csv = (
b"date,name,latitude,longitude\n"
b"2020-01-01,Barra da Lagoa,-27.574,-48.422\n"
b"2020-03-04,S\xe3o Paulo,-23.561,-46.645\n"
b"2020-04-05,Salta,-24.793:-65.408"
)
assert latin1_csv.decode("latin-1").split("\n")[2].split(",")[1] == "São Paulo"
csv_path = str(tmpdir / "test.csv")
with open(csv_path, "wb") as fp:
fp.write(latin1_csv)
# First attempt should error:
bad_result = CliRunner().invoke(
cli.cli,
["insert", db_path, "places", csv_path, "--csv"],
catch_exceptions=False,
)
assert bad_result.exit_code == 1
assert (
"The input you provided uses a character encoding other than utf-8"
in bad_result.output
)
# Using --encoding=latin-1 should work
good_result = CliRunner().invoke(
cli.cli,
["insert", db_path, "places", csv_path, "--encoding", "latin-1", "--csv"],
catch_exceptions=False,
)
assert good_result.exit_code == 0
db = Database(db_path)
assert list(db["places"].rows) == [
{
"date": "2020-01-01",
"name": "Barra da Lagoa",
"latitude": "-27.574",
"longitude": "-48.422",
},
{
"date": "2020-03-04",
"name": "São Paulo",
"latitude": "-23.561",
"longitude": "-46.645",
},
{
"date": "2020-04-05",
"name": "Salta",
"latitude": "-24.793:-65.408",
"longitude": None,
},
]
@pytest.mark.parametrize("fts", ["FTS4", "FTS5"])
@pytest.mark.parametrize(
"extra_arg,expected",
[
(
None,
'[{"rowid": 2, "id": 2, "title": "Title the second"}]\n',
),
("--csv", "rowid,id,title\n2,2,Title the second\n"),
],
)
def test_search(tmpdir, fts, extra_arg, expected):
db_path = str(tmpdir / "test.db")
db = Database(db_path)
db["articles"].insert_all(
[
{"id": 1, "title": "Title the first"},
{"id": 2, "title": "Title the second"},
{"id": 3, "title": "Title the third"},
],
pk="id",
)
db["articles"].enable_fts(["title"], fts_version=fts)
result = CliRunner().invoke(
cli.cli,
["search", db_path, "articles", "second"] + ([extra_arg] if extra_arg else []),
catch_exceptions=False,
)
assert result.exit_code == 0
assert result.output.replace("\r", "") == expected
def test_search_quote(tmpdir):
db_path = str(tmpdir / "test.db")
db = Database(db_path)
db["creatures"].insert({"name": "dog."}).enable_fts(["name"])
# Without --quote should return an error
error_result = CliRunner().invoke(cli.cli, ["search", db_path, "creatures", 'dog"'])
assert error_result.exit_code == 1
assert error_result.output == (
"Error: unterminated string\n\n"
"Try running this again with the --quote option\n"
)
# With --quote it should work
result = CliRunner().invoke(
cli.cli, ["search", db_path, "creatures", 'dog"', "--quote"]
)
assert result.exit_code == 0
assert result.output.strip() == '[{"rowid": 1, "name": "dog."}]'
def test_indexes(tmpdir):
db_path = str(tmpdir / "test.db")
db = Database(db_path)
db.conn.executescript(
"""
create table Gosh (c1 text, c2 text, c3 text);
create index Gosh_idx on Gosh(c2, c3 desc);
"""
)
result = CliRunner().invoke(
cli.cli,
["indexes", str(db_path)],
catch_exceptions=False,
)
assert result.exit_code == 0
assert json.loads(result.output) == [
{
"table": "Gosh",
"index_name": "Gosh_idx",
"seqno": 0,
"cid": 1,
"name": "c2",
"desc": 0,
"coll": "BINARY",
"key": 1,
},
{
"table": "Gosh",
"index_name": "Gosh_idx",
"seqno": 1,
"cid": 2,
"name": "c3",
"desc": 1,
"coll": "BINARY",
"key": 1,
},
]
result2 = CliRunner().invoke(
cli.cli,
["indexes", str(db_path), "--aux"],
catch_exceptions=False,
)
assert result2.exit_code == 0
assert json.loads(result2.output) == [
{
"table": "Gosh",
"index_name": "Gosh_idx",
"seqno": 0,
"cid": 1,
"name": "c2",
"desc": 0,
"coll": "BINARY",
"key": 1,
},
{
"table": "Gosh",
"index_name": "Gosh_idx",
"seqno": 1,
"cid": 2,
"name": "c3",
"desc": 1,
"coll": "BINARY",
"key": 1,
},
{
"table": "Gosh",
"index_name": "Gosh_idx",
"seqno": 2,
"cid": -1,
"name": None,
"desc": 0,
"coll": "BINARY",
"key": 0,
},
]
_TRIGGERS_EXPECTED = (
'[{"name": "blah", "table": "articles", "sql": "CREATE TRIGGER blah '
'AFTER INSERT ON articles\\nBEGIN\\n UPDATE counter SET count = count + 1;\\nEND"}]\n'
)
@pytest.mark.parametrize(
"extra_args,expected",
[
([], _TRIGGERS_EXPECTED),
(["articles"], _TRIGGERS_EXPECTED),
(["counter"], "[]\n"),
],
)
def test_triggers(tmpdir, extra_args, expected):
db_path = str(tmpdir / "test.db")
db = Database(db_path)
db["articles"].insert(
{"id": 1, "title": "Title the first"},
pk="id",
)
db["counter"].insert({"count": 1})
db.conn.execute(
textwrap.dedent(
"""
CREATE TRIGGER blah AFTER INSERT ON articles
BEGIN
UPDATE counter SET count = count + 1;
END
"""
)
)
args = ["triggers", db_path]
if extra_args:
args.extend(extra_args)
result = CliRunner().invoke(
cli.cli,
args,
catch_exceptions=False,
)
assert result.exit_code == 0
assert result.output == expected
@pytest.mark.parametrize(
"options,expected",
(
(
[],
(
"CREATE TABLE [dogs] (\n"
" [id] INTEGER,\n"
" [name] TEXT\n"
");\n"
"CREATE TABLE [chickens] (\n"
" [id] INTEGER,\n"
" [name] TEXT,\n"
" [breed] TEXT\n"
");\n"
"CREATE INDEX [idx_chickens_breed]\n"
" ON [chickens] ([breed]);\n"
),
),
(
["dogs"],
("CREATE TABLE [dogs] (\n" " [id] INTEGER,\n" " [name] TEXT\n" ")\n"),
),
(
["chickens", "dogs"],
(
"CREATE TABLE [chickens] (\n"
" [id] INTEGER,\n"
" [name] TEXT,\n"
" [breed] TEXT\n"
")\n"
"CREATE TABLE [dogs] (\n"
" [id] INTEGER,\n"
" [name] TEXT\n"
")\n"
),
),
),
)
def test_schema(tmpdir, options, expected):
db_path = str(tmpdir / "test.db")
db = Database(db_path)
db["dogs"].create({"id": int, "name": str})
db["chickens"].create({"id": int, "name": str, "breed": str})
db["chickens"].create_index(["breed"])
result = CliRunner().invoke(
cli.cli,
["schema", db_path] + options,
catch_exceptions=False,
)
assert result.exit_code == 0
assert result.output == expected
def test_long_csv_column_value(tmpdir):
db_path = str(tmpdir / "test.db")
csv_path = str(tmpdir / "test.csv")
with open(csv_path, "w") as csv_file:
long_string = "a" * 131073
csv_file.write("id,text\n")
csv_file.write("1,{}\n".format(long_string))
result = CliRunner().invoke(
cli.cli,
["insert", db_path, "bigtable", csv_path, "--csv"],
catch_exceptions=False,
)
assert result.exit_code == 0
db = Database(db_path)
rows = list(db["bigtable"].rows)
assert len(rows) == 1
assert rows[0]["text"] == long_string
@pytest.mark.parametrize(
"args,tsv",
(
(["--csv", "--no-headers"], False),
(["--no-headers"], False),
(["--tsv", "--no-headers"], True),
),
)
def test_import_no_headers(tmpdir, args, tsv):
db_path = str(tmpdir / "test.db")
csv_path = str(tmpdir / "test.csv")
with open(csv_path, "w") as csv_file:
sep = "\t" if tsv else ","
csv_file.write("Cleo{sep}Dog{sep}5\n".format(sep=sep))
csv_file.write("Tracy{sep}Spider{sep}7\n".format(sep=sep))
result = CliRunner().invoke(
cli.cli,
["insert", db_path, "creatures", csv_path] + args,
catch_exceptions=False,
)
assert result.exit_code == 0, result.output
db = Database(db_path)
schema = db["creatures"].schema
assert schema == (
"CREATE TABLE [creatures] (\n"
" [untitled_1] TEXT,\n"
" [untitled_2] TEXT,\n"
" [untitled_3] TEXT\n"
")"
)
rows = list(db["creatures"].rows)
assert rows == [
{"untitled_1": "Cleo", "untitled_2": "Dog", "untitled_3": "5"},
{"untitled_1": "Tracy", "untitled_2": "Spider", "untitled_3": "7"},
]
def test_attach(tmpdir):
foo_path = str(tmpdir / "foo.db")
bar_path = str(tmpdir / "bar.db")
db = Database(foo_path)
with db.conn:
db["foo"].insert({"id": 1, "text": "foo"})
db2 = Database(bar_path)
with db2.conn:
db2["bar"].insert({"id": 1, "text": "bar"})
db.attach("bar", bar_path)
sql = "select * from foo union all select * from bar.bar"
result = CliRunner().invoke(
cli.cli,
[foo_path, "--attach", "bar", bar_path, sql],
catch_exceptions=False,
)
assert json.loads(result.output) == [
{"id": 1, "text": "foo"},
{"id": 1, "text": "bar"},
]
def test_csv_insert_bom(tmpdir):
db_path = str(tmpdir / "test.db")
bom_csv_path = str(tmpdir / "bom.csv")
with open(bom_csv_path, "wb") as fp:
fp.write(b"\xef\xbb\xbfname,age\nCleo,5")
result = CliRunner().invoke(
cli.cli,
["insert", db_path, "broken", bom_csv_path, "--encoding", "utf-8", "--csv"],
catch_exceptions=False,
)
assert result.exit_code == 0
result2 = CliRunner().invoke(
cli.cli,
["insert", db_path, "fixed", bom_csv_path, "--csv"],
catch_exceptions=False,
)
assert result2.exit_code == 0
db = Database(db_path)
tables = db.execute("select name, sql from sqlite_master").fetchall()
assert tables == [
("broken", "CREATE TABLE [broken] (\n [\ufeffname] TEXT,\n [age] TEXT\n)"),
("fixed", "CREATE TABLE [fixed] (\n [name] TEXT,\n [age] TEXT\n)"),
]
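# For reference: the "broken" table above shows a UTF-8 BOM glued onto the first
# column name when the file is decoded as plain utf-8. Decoding with utf-8-sig
# strips the BOM, which matches the behaviour of the second insert here (no
# --encoding given). Standard-library sketch:
#
#   import csv, io
#   raw = b"\xef\xbb\xbfname,age\nCleo,5"
#   list(csv.DictReader(io.TextIOWrapper(io.BytesIO(raw), "utf-8-sig")))
#   # -> [{'name': 'Cleo', 'age': '5'}]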
@pytest.mark.parametrize("option_or_env_var", (None, "-d", "--detect-types"))
def test_insert_detect_types(tmpdir, option_or_env_var):
db_path = str(tmpdir / "test.db")
data = "name,age,weight\nCleo,6,45.5\nDori,1,3.5"
extra = []
if option_or_env_var:
extra = [option_or_env_var]
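# -d/--detect-types (or SQLITE_UTILS_DETECT_TYPES=1) should store the CSV values as integers and floats instead of text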
def _test():
result = CliRunner().invoke(
cli.cli,
["insert", db_path, "creatures", "-", "--csv"] + extra,
catch_exceptions=False,
input=data,
)
assert result.exit_code == 0
db = Database(db_path)
assert list(db["creatures"].rows) == [
{"name": "Cleo", "age": 6, "weight": 45.5},
{"name": "Dori", "age": 1, "weight": 3.5},
]
if option_or_env_var is None:
# Use environment variable instead of option
with mock.patch.dict(os.environ, {"SQLITE_UTILS_DETECT_TYPES": "1"}):
_test()
else:
_test()
@pytest.mark.parametrize("option", ("-d", "--detect-types"))
def test_upsert_detect_types(tmpdir, option):
db_path = str(tmpdir / "test.db")
data = "id,name,age,weight\n1,Cleo,6,45.5\n2,Dori,1,3.5"
result = CliRunner().invoke(
cli.cli,
["upsert", db_path, "creatures", "-", "--csv", "--pk", "id"] + [option],
catch_exceptions=False,
input=data,
)
assert result.exit_code == 0
db = Database(db_path)
assert list(db["creatures"].rows) == [
{"id": 1, "name": "Cleo", "age": 6, "weight": 45.5},
{"id": 2, "name": "Dori", "age": 1, "weight": 3.5},
]
def test_integer_overflow_error(tmpdir):
db_path = str(tmpdir / "test.db")
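# SQLite integers are 64-bit signed, so a value this large cannot be stored and should produce a clean error message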
result = CliRunner().invoke(
cli.cli,
["insert", db_path, "items", "-"],
input=json.dumps({"bignumber": 34223049823094832094802398430298048240}),
)
assert result.exit_code == 1
assert result.output == (
"Error: Python int too large to convert to SQLite INTEGER\n\n"
"sql = INSERT INTO [items] ([bignumber]) VALUES (?)\n"
"parameters = [34223049823094832094802398430298048240]\n"
)
def test_python_dash_m():
"Tool can be run using python -m sqlite_utils"
result = subprocess.run(
[sys.executable, "-m", "sqlite_utils", "--help"], stdout=subprocess.PIPE
)
assert result.returncode == 0
assert b"Commands for interacting with a SQLite database" in result.stdout
@pytest.mark.parametrize("enable_wal", (False, True))
def test_create_database(tmpdir, enable_wal):
db_path = tmpdir / "test.db"
assert not db_path.exists()
args = ["create-database", str(db_path)]
if enable_wal:
args.append("--enable-wal")
result = CliRunner().invoke(cli.cli, args)
assert result.exit_code == 0, result.output
assert db_path.exists()
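# Every SQLite database file starts with the 16-byte header "SQLite format 3\x00"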
assert db_path.read_binary()[:16] == b"SQLite format 3\x00"
db = Database(str(db_path))
if enable_wal:
assert db.journal_mode == "wal"
else:
assert db.journal_mode == "delete"
@pytest.mark.parametrize(
"options,expected",
(
(
[],
[
{"tbl": "two_indexes", "idx": "idx_two_indexes_species", "stat": "1 1"},
{"tbl": "two_indexes", "idx": "idx_two_indexes_name", "stat": "1 1"},
{"tbl": "one_index", "idx": "idx_one_index_name", "stat": "1 1"},
],
),
(
["one_index"],
[
{"tbl": "one_index", "idx": "idx_one_index_name", "stat": "1 1"},
],
),
(
["idx_two_indexes_name"],
[
{"tbl": "two_indexes", "idx": "idx_two_indexes_name", "stat": "1 1"},
],
),
),
)
def test_analyze(tmpdir, options, expected):
db_path = str(tmpdir / "test.db")
db = Database(db_path)
db["one_index"].insert({"id": 1, "name": "Cleo"}, pk="id")
db["one_index"].create_index(["name"])
db["two_indexes"].insert({"id": 1, "name": "Cleo", "species": "dog"}, pk="id")
db["two_indexes"].create_index(["name"])
db["two_indexes"].create_index(["species"])
result = CliRunner().invoke(cli.cli, ["analyze", db_path] + options)
assert result.exit_code == 0
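# ANALYZE writes its statistics to the sqlite_stat1 table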
assert list(db["sqlite_stat1"].rows) == expected
def test_rename_table(tmpdir):
db_path = str(tmpdir / "test.db")
db = Database(db_path)
db["one"].insert({"id": 1, "name": "Cleo"}, pk="id")
# First try a non-existent table
result_error = CliRunner().invoke(
cli.cli,
["rename-table", db_path, "missing", "two"],
catch_exceptions=False,
)
assert result_error.exit_code == 1
assert result_error.output == (
'Error: Table "missing" could not be renamed. ' "no such table: missing\n"
)
# And check --ignore works
result_error2 = CliRunner().invoke(
cli.cli,
["rename-table", db_path, "missing", "two", "--ignore"],
catch_exceptions=False,
)
assert result_error2.exit_code == 0
previous_columns = db["one"].columns_dict
# Now try for a table that exists
result = CliRunner().invoke(
cli.cli,
["rename-table", db_path, "one", "two"],
catch_exceptions=False,
)
assert result.exit_code == 0
assert db["two"].columns_dict == previous_columns
def test_duplicate_table(tmpdir):
db_path = str(tmpdir / "test.db")
db = Database(db_path)
db["one"].insert({"id": 1, "name": "Cleo"}, pk="id")
# First try a non-existent table
result_error = CliRunner().invoke(
cli.cli,
["duplicate", db_path, "missing", "two"],
catch_exceptions=False,
)
assert result_error.exit_code == 1
assert result_error.output == 'Error: Table "missing" does not exist\n'
# And check --ignore works
result_error2 = CliRunner().invoke(
cli.cli,
["duplicate", db_path, "missing", "two", "--ignore"],
catch_exceptions=False,
)
assert result_error2.exit_code == 0
# Now try for a table that exists
result = CliRunner().invoke(
cli.cli,
["duplicate", db_path, "one", "two"],
catch_exceptions=False,
)
assert result.exit_code == 0
assert db["one"].columns_dict == db["two"].columns_dict
assert list(db["one"].rows) == list(db["two"].rows)
@pytest.mark.skipif(not _has_compiled_ext(), reason="Requires compiled ext.c")
@pytest.mark.parametrize(
"entrypoint,should_pass,should_fail",
(
(None, ("a",), ("b", "c")),
("sqlite3_ext_b_init", ("b"), ("a", "c")),
("sqlite3_ext_c_init", ("c"), ("a", "b")),
),
)
def test_load_extension(entrypoint, should_pass, should_fail):
ext = COMPILED_EXTENSION_PATH
if entrypoint:
ext += ":" + entrypoint
for func in should_pass:
result = CliRunner().invoke(
cli.cli,
["memory", "select {}()".format(func), "--load-extension", ext],
catch_exceptions=False,
)
assert result.exit_code == 0
for func in should_fail:
result = CliRunner().invoke(
cli.cli,
["memory", "select {}()".format(func), "--load-extension", ext],
catch_exceptions=False,
)
assert result.exit_code == 1
@pytest.mark.parametrize("strict", (False, True))
def test_create_table_strict(strict):
runner = CliRunner()
with runner.isolated_filesystem():
db = Database("test.db")
result = runner.invoke(
cli.cli,
["create-table", "test.db", "items", "id", "integer", "w", "float"]
+ (["--strict"] if strict else []),
)
assert result.exit_code == 0
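# STRICT tables require SQLite 3.37+, so on older versions supports_strict is False and strict stays False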
assert db["items"].strict == strict or not db.supports_strict
# Should have a floating point column
assert db["items"].columns_dict == {"id": int, "w": float}
@pytest.mark.parametrize("method", ("insert", "upsert"))
@pytest.mark.parametrize("strict", (False, True))
def test_insert_upsert_strict(tmpdir, method, strict):
db_path = str(tmpdir / "test.db")
result = CliRunner().invoke(
cli.cli,
[method, db_path, "items", "-", "--csv", "--pk", "id"]
+ (["--strict"] if strict else []),
input="id\n1",
)
assert result.exit_code == 0
db = Database(db_path)
assert db["items"].strict == strict or not db.supports_strict
</document_content>
</document>
<document index="18">
<source>./tests/test_cli_bulk.py</source>
<document_content>
from click.testing import CliRunner
from sqlite_utils import cli, Database
import pathlib
import pytest
import subprocess
import sys
import time
@pytest.fixture
def test_db_and_path(tmpdir):
db_path = str(pathlib.Path(tmpdir) / "data.db")
db = Database(db_path)
db["example"].insert_all(
[
{"id": 1, "name": "One"},
{"id": 2, "name": "Two"},
],
pk="id",
)
return db, db_path
def test_cli_bulk(test_db_and_path):
db, db_path = test_db_and_path
result = CliRunner().invoke(
cli.cli,
[
"bulk",
db_path,
"insert into example (id, name) values (:id, myupper(:name))",
"-",
"--nl",
"--functions",
"myupper = lambda s: s.upper()",
],
input='{"id": 3, "name": "Three"}\n{"id": 4, "name": "Four"}\n',
)
assert result.exit_code == 0, result.output
assert [
{"id": 1, "name": "One"},
{"id": 2, "name": "Two"},
{"id": 3, "name": "THREE"},
{"id": 4, "name": "FOUR"},
] == list(db["example"].rows)
def test_cli_bulk_batch_size(test_db_and_path):
db, db_path = test_db_and_path
proc = subprocess.Popen(
[
sys.executable,
"-m",
"sqlite_utils",
"bulk",
db_path,
"insert into example (id, name) values (:id, :name)",
"-",
"--nl",
"--batch-size",
"2",
],
stdin=subprocess.PIPE,
stdout=sys.stdout,
)
# Writing one record should not commit
proc.stdin.write(b'{"id": 3, "name": "Three"}\n\n')
proc.stdin.flush()
time.sleep(1)
assert db["example"].count == 2
# Writing another should trigger a commit:
proc.stdin.write(b'{"id": 4, "name": "Four"}\n\n')
proc.stdin.flush()
time.sleep(1)
assert db["example"].count == 4
proc.stdin.close()
proc.wait()
assert proc.returncode == 0
def test_cli_bulk_error(test_db_and_path):
_, db_path = test_db_and_path
result = CliRunner().invoke(
cli.cli,
[
"bulk",
db_path,
"insert into example (id, name) value (:id, :name)",
"-",
"--nl",
],
input='{"id": 3, "name": "Three"}',
)
assert result.exit_code == 1
assert result.output == 'Error: near "value": syntax error\n'
</document_content>
</document>
<document index="19">
<source>./tests/test_cli_convert.py</source>
<document_content>
from click.testing import CliRunner
from sqlite_utils import cli
import sqlite_utils
import json
import textwrap
import pathlib
import pytest
@pytest.fixture
def test_db_and_path(fresh_db_and_path):
db, db_path = fresh_db_and_path
db["example"].insert_all(
[
{"id": 1, "dt": "5th October 2019 12:04"},
{"id": 2, "dt": "6th October 2019 00:05:06"},
{"id": 3, "dt": ""},
{"id": 4, "dt": None},
],
pk="id",
)
return db, db_path
@pytest.fixture
def fresh_db_and_path(tmpdir):
db_path = str(pathlib.Path(tmpdir) / "data.db")
db = sqlite_utils.Database(db_path)
return db, db_path
@pytest.mark.parametrize(
"code",
[
"return value.replace('October', 'Spooktober')",
# Return is optional:
"value.replace('October', 'Spooktober')",
# Multiple lines are supported:
"v = value.replace('October', 'Spooktober')\nreturn v",
# Can also define a convert() function
"def convert(value): return value.replace('October', 'Spooktober')",
# ... with imports
"import re\n\ndef convert(value): return value.replace('October', 'Spooktober')",
],
)
def test_convert_code(fresh_db_and_path, code):
db, db_path = fresh_db_and_path
db["t"].insert({"text": "October"})
result = CliRunner().invoke(
cli.cli, ["convert", db_path, "t", "text", code], catch_exceptions=False
)
assert result.exit_code == 0, result.output
value = list(db["t"].rows)[0]["text"]
assert value == "Spooktober"
@pytest.mark.parametrize(
"bad_code",
(
"def foo(value)",
"$",
),
)
def test_convert_code_errors(fresh_db_and_path, bad_code):
db, db_path = fresh_db_and_path
db["t"].insert({"text": "October"})
result = CliRunner().invoke(
cli.cli, ["convert", db_path, "t", "text", bad_code], catch_exceptions=False
)
assert result.exit_code == 1
assert result.output == "Error: Could not compile code\n"
def test_convert_import(test_db_and_path):
db, db_path = test_db_and_path
result = CliRunner().invoke(
cli.cli,
[
"convert",
db_path,
"example",
"dt",
"return re.sub('O..', 'OXX', value)",
"--import",
"re",
],
)
assert result.exit_code == 0, result.output
assert [
{"id": 1, "dt": "5th OXXober 2019 12:04"},
{"id": 2, "dt": "6th OXXober 2019 00:05:06"},
{"id": 3, "dt": ""},
{"id": 4, "dt": None},
] == list(db["example"].rows)
def test_convert_import_nested(fresh_db_and_path):
db, db_path = fresh_db_and_path
db["example"].insert({"xml": '<item name="Cleo" />'})
result = CliRunner().invoke(
cli.cli,
[
"convert",
db_path,
"example",
"xml",
'xml.etree.ElementTree.fromstring(value).attrib["name"]',
"--import",
"xml.etree.ElementTree",
],
)
assert result.exit_code == 0, result.output
assert [
{"xml": "Cleo"},
] == list(db["example"].rows)
def test_convert_dryrun(test_db_and_path):
db, db_path = test_db_and_path
result = CliRunner().invoke(
cli.cli,
[
"convert",
db_path,
"example",
"dt",
"return re.sub('O..', 'OXX', value)",
"--import",
"re",
"--dry-run",
],
)
assert result.exit_code == 0
assert result.output.strip() == (
"5th October 2019 12:04\n"
" --- becomes:\n"
"5th OXXober 2019 12:04\n"
"\n"
"6th October 2019 00:05:06\n"
" --- becomes:\n"
"6th OXXober 2019 00:05:06\n"
"\n"
"\n"
" --- becomes:\n"
"\n"
"\n"
"None\n"
" --- becomes:\n"
"None\n\n"
"Would affect 4 rows"
)
# But it should not have actually modified the table data
assert list(db["example"].rows) == [
{"id": 1, "dt": "5th October 2019 12:04"},
{"id": 2, "dt": "6th October 2019 00:05:06"},
{"id": 3, "dt": ""},
{"id": 4, "dt": None},
]
# Test with a where clause too
result = CliRunner().invoke(
cli.cli,
[
"convert",
db_path,
"example",
"dt",
"return re.sub('O..', 'OXX', value)",
"--import",
"re",
"--dry-run",
"--where",
"id = :id",
"-p",
"id",
"4",
],
)
assert result.exit_code == 0
assert result.output.strip().split("\n")[-1] == "Would affect 1 row"
def test_convert_multi_dryrun(test_db_and_path):
db_path = test_db_and_path[1]
result = CliRunner().invoke(
cli.cli,
[
"convert",
db_path,
"example",
"dt",
"{'foo': 'bar', 'baz': 1}",
"--dry-run",
"--multi",
],
)
assert result.exit_code == 0
assert result.output.strip() == (
"5th October 2019 12:04\n"
" --- becomes:\n"
'{"foo": "bar", "baz": 1}\n'
"\n"
"6th October 2019 00:05:06\n"
" --- becomes:\n"
'{"foo": "bar", "baz": 1}\n'
"\n"
"\n"
" --- becomes:\n"
"\n"
"\n"
"None\n"
" --- becomes:\n"
"None\n"
"\n"
"Would affect 4 rows"
)
@pytest.mark.parametrize("drop", (True, False))
def test_convert_output_column(test_db_and_path, drop):
db, db_path = test_db_and_path
args = [
"convert",
db_path,
"example",
"dt",
"value.replace('October', 'Spooktober')",
"--output",
"newcol",
]
if drop:
args += ["--drop"]
result = CliRunner().invoke(cli.cli, args)
assert result.exit_code == 0, result.output
expected = [
{
"id": 1,
"dt": "5th October 2019 12:04",
"newcol": "5th Spooktober 2019 12:04",
},
{
"id": 2,
"dt": "6th October 2019 00:05:06",
"newcol": "6th Spooktober 2019 00:05:06",
},
{"id": 3, "dt": "", "newcol": ""},
{"id": 4, "dt": None, "newcol": None},
]
if drop:
for row in expected:
del row["dt"]
assert list(db["example"].rows) == expected
@pytest.mark.parametrize(
"output_type,expected",
(
("text", [(1, "1"), (2, "2"), (3, "3"), (4, "4")]),
("float", [(1, 1.0), (2, 2.0), (3, 3.0), (4, 4.0)]),
("integer", [(1, 1), (2, 2), (3, 3), (4, 4)]),
(None, [(1, "1"), (2, "2"), (3, "3"), (4, "4")]),
),
)
def test_convert_output_column_output_type(test_db_and_path, output_type, expected):
db, db_path = test_db_and_path
args = [
"convert",
db_path,
"example",
"id",
"value",
"--output",
"new_id",
]
if output_type:
args += ["--output-type", output_type]
result = CliRunner().invoke(
cli.cli,
args,
)
assert result.exit_code == 0, result.output
assert expected == list(db.execute("select id, new_id from example"))
@pytest.mark.parametrize(
"options,expected_error",
[
(
[
"dt",
"id",
"value.replace('October', 'Spooktober')",
"--output",
"newcol",
],
"Cannot use --output with more than one column",
),
(
[
"dt",
"value.replace('October', 'Spooktober')",
"--output",
"newcol",
"--output-type",
"invalid",
],
"Error: Invalid value for '--output-type'",
),
(
[
"value.replace('October', 'Spooktober')",
],
"Missing argument 'COLUMNS...'",
),
],
)
def test_convert_output_error(test_db_and_path, options, expected_error):
db_path = test_db_and_path[1]
result = CliRunner().invoke(
cli.cli,
[
"convert",
db_path,
"example",
]
+ options,
)
assert result.exit_code != 0
assert expected_error in result.output
@pytest.mark.parametrize("drop", (True, False))
def test_convert_multi(fresh_db_and_path, drop):
db, db_path = fresh_db_and_path
db["creatures"].insert_all(
[
{"id": 1, "name": "Simon"},
{"id": 2, "name": "Cleo"},
],
pk="id",
)
args = [
"convert",
db_path,
"creatures",
"name",
"--multi",
'{"upper": value.upper(), "lower": value.lower()}',
]
if drop:
args += ["--drop"]
result = CliRunner().invoke(cli.cli, args)
assert result.exit_code == 0, result.output
expected = [
{"id": 1, "name": "Simon", "upper": "SIMON", "lower": "simon"},
{"id": 2, "name": "Cleo", "upper": "CLEO", "lower": "cleo"},
]
if drop:
for row in expected:
del row["name"]
assert list(db["creatures"].rows) == expected
def test_convert_multi_complex_column_types(fresh_db_and_path):
db, db_path = fresh_db_and_path
db["rows"].insert_all(
[
{"id": 1},
{"id": 2},
{"id": 3},
{"id": 4},
],
pk="id",
)
code = textwrap.dedent(
"""
if value == 1:
return {"is_str": "", "is_float": 1.2, "is_int": None}
elif value == 2:
return {"is_float": 1, "is_int": 12}
elif value == 3:
return {"is_bytes": b"blah"}
"""
)
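# Each returned dictionary can introduce new columns - missing keys become NULL and the column types are inferred from the returned values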
result = CliRunner().invoke(
cli.cli,
[
"convert",
db_path,
"rows",
"id",
"--multi",
code,
],
)
assert result.exit_code == 0, result.output
assert list(db["rows"].rows) == [
{"id": 1, "is_str": "", "is_float": 1.2, "is_int": None, "is_bytes": None},
{"id": 2, "is_str": None, "is_float": 1.0, "is_int": 12, "is_bytes": None},
{
"id": 3,
"is_str": None,
"is_float": None,
"is_int": None,
"is_bytes": b"blah",
},
{"id": 4, "is_str": None, "is_float": None, "is_int": None, "is_bytes": None},
]
assert db["rows"].schema == (
"CREATE TABLE [rows] (\n"
" [id] INTEGER PRIMARY KEY\n"
", [is_str] TEXT, [is_float] FLOAT, [is_int] INTEGER, [is_bytes] BLOB)"
)
@pytest.mark.parametrize("delimiter", [None, ";", "-"])
def test_recipe_jsonsplit(tmpdir, delimiter):
db_path = str(pathlib.Path(tmpdir) / "data.db")
db = sqlite_utils.Database(db_path)
db["example"].insert_all(
[
{"id": 1, "tags": (delimiter or ",").join(["foo", "bar"])},
{"id": 2, "tags": (delimiter or ",").join(["bar", "baz"])},
],
pk="id",
)
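# Both "r" and "recipes" refer to the bundled recipes module inside convert code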
code = "r.jsonsplit(value)"
if delimiter:
code = 'recipes.jsonsplit(value, delimiter="{}")'.format(delimiter)
args = ["convert", db_path, "example", "tags", code]
result = CliRunner().invoke(cli.cli, args)
assert result.exit_code == 0, result.output
assert list(db["example"].rows) == [
{"id": 1, "tags": '["foo", "bar"]'},
{"id": 2, "tags": '["bar", "baz"]'},
]
@pytest.mark.parametrize(
"type,expected_array",
(
(None, ["1", "2", "3"]),
("float", [1.0, 2.0, 3.0]),
("int", [1, 2, 3]),
),
)
def test_recipe_jsonsplit_type(fresh_db_and_path, type, expected_array):
db, db_path = fresh_db_and_path
db["example"].insert_all(
[
{"id": 1, "records": "1,2,3"},
],
pk="id",
)
code = "r.jsonsplit(value)"
if type:
code = "recipes.jsonsplit(value, type={})".format(type)
args = ["convert", db_path, "example", "records", code]
result = CliRunner().invoke(cli.cli, args)
assert result.exit_code == 0, result.output
assert json.loads(db["example"].get(1)["records"]) == expected_array
@pytest.mark.parametrize("drop", (True, False))
def test_recipe_jsonsplit_output(fresh_db_and_path, drop):
db, db_path = fresh_db_and_path
db["example"].insert_all(
[
{"id": 1, "records": "1,2,3"},
],
pk="id",
)
code = "r.jsonsplit(value)"
args = ["convert", db_path, "example", "records", code, "--output", "tags"]
if drop:
args += ["--drop"]
result = CliRunner().invoke(cli.cli, args)
assert result.exit_code == 0, result.output
expected = {
"id": 1,
"records": "1,2,3",
"tags": '["1", "2", "3"]',
}
if drop:
del expected["records"]
assert db["example"].get(1) == expected
def test_cannot_use_drop_without_multi_or_output(fresh_db_and_path):
args = ["convert", fresh_db_and_path[1], "example", "records", "value", "--drop"]
result = CliRunner().invoke(cli.cli, args)
assert result.exit_code == 1, result.output
assert "Error: --drop can only be used with --output or --multi" in result.output
def test_cannot_use_multi_with_more_than_one_column(fresh_db_and_path):
args = [
"convert",
fresh_db_and_path[1],
"example",
"records",
"othercol",
"value",
"--multi",
]
result = CliRunner().invoke(cli.cli, args)
assert result.exit_code == 1, result.output
assert "Error: Cannot use --multi with more than one column" in result.output
def test_multi_with_bad_function(test_db_and_path):
args = [
"convert",
test_db_and_path[1],
"example",
"dt",
"value.upper()",
"--multi",
]
result = CliRunner().invoke(cli.cli, args)
assert result.exit_code == 1, result.output
assert "When using --multi code must return a Python dictionary" in result.output
def test_convert_where(test_db_and_path):
db, db_path = test_db_and_path
result = CliRunner().invoke(
cli.cli,
[
"convert",
db_path,
"example",
"dt",
"str(value).upper()",
"--where",
"id = :id",
"-p",
"id",
2,
],
)
assert result.exit_code == 0, result.output
assert list(db["example"].rows) == [
{"id": 1, "dt": "5th October 2019 12:04"},
{"id": 2, "dt": "6TH OCTOBER 2019 00:05:06"},
{"id": 3, "dt": ""},
{"id": 4, "dt": None},
]
def test_convert_where_multi(fresh_db_and_path):
db, db_path = fresh_db_and_path
db["names"].insert_all(
[{"id": 1, "name": "Cleo"}, {"id": 2, "name": "Bants"}], pk="id"
)
result = CliRunner().invoke(
cli.cli,
[
"convert",
db_path,
"names",
"name",
'{"upper": value.upper()}',
"--where",
"id = :id",
"-p",
"id",
2,
"--multi",
],
)
assert result.exit_code == 0, result.output
assert list(db["names"].rows) == [
{"id": 1, "name": "Cleo", "upper": None},
{"id": 2, "name": "Bants", "upper": "BANTS"},
]
def test_convert_code_standard_input(fresh_db_and_path):
db, db_path = fresh_db_and_path
db["names"].insert_all([{"id": 1, "name": "Cleo"}], pk="id")
result = CliRunner().invoke(
cli.cli,
[
"convert",
db_path,
"names",
"name",
"-",
],
input="value.upper()",
)
assert result.exit_code == 0, result.output
assert list(db["names"].rows) == [
{"id": 1, "name": "CLEO"},
]
def test_convert_hyphen_workaround(fresh_db_and_path):
db, db_path = fresh_db_and_path
db["names"].insert_all([{"id": 1, "name": "Cleo"}], pk="id")
result = CliRunner().invoke(
cli.cli,
["convert", db_path, "names", "name", '"-"'],
)
assert result.exit_code == 0, result.output
assert list(db["names"].rows) == [
{"id": 1, "name": "-"},
]
def test_convert_initialization_pattern(fresh_db_and_path):
db, db_path = fresh_db_and_path
db["names"].insert_all([{"id": 1, "name": "Cleo"}], pk="id")
result = CliRunner().invoke(
cli.cli,
[
"convert",
db_path,
"names",
"name",
"-",
],
input="import random\nrandom.seed(1)\ndef convert(value): return random.randint(0, 100)",
)
assert result.exit_code == 0, result.output
assert list(db["names"].rows) == [
{"id": 1, "name": "17"},
]
@pytest.mark.parametrize(
"no_skip_false,expected",
(
(True, 1),
(False, 0),
),
)
def test_convert_no_skip_false(fresh_db_and_path, no_skip_false, expected):
db, db_path = fresh_db_and_path
args = [
"convert",
db_path,
"t",
"x",
"-",
]
if no_skip_false:
args.append("--no-skip-false")
db["t"].insert_all([{"x": 0}, {"x": 1}])
assert db["t"].get(1)["x"] == 0
assert db["t"].get(2)["x"] == 1
result = CliRunner().invoke(cli.cli, args, input="value + 1")
assert result.exit_code == 0, result.output
assert db["t"].get(1)["x"] == expected
assert db["t"].get(2)["x"] == 2
</document_content>
</document>
<document index="20">
<source>./tests/test_cli_insert.py</source>
<document_content>
from sqlite_utils import cli, Database
from click.testing import CliRunner
import json
import pytest
import subprocess
import sys
import time
def test_insert_simple(tmpdir):
json_path = str(tmpdir / "dog.json")
db_path = str(tmpdir / "dogs.db")
with open(json_path, "w") as fp:
fp.write(json.dumps({"name": "Cleo", "age": 4}))
result = CliRunner().invoke(cli.cli, ["insert", db_path, "dogs", json_path])
assert result.exit_code == 0
assert [{"age": 4, "name": "Cleo"}] == list(
Database(db_path).query("select * from dogs")
)
db = Database(db_path)
assert ["dogs"] == db.table_names()
assert [] == db["dogs"].indexes
def test_insert_from_stdin(tmpdir):
db_path = str(tmpdir / "dogs.db")
result = CliRunner().invoke(
cli.cli,
["insert", db_path, "dogs", "-"],
input=json.dumps({"name": "Cleo", "age": 4}),
)
assert result.exit_code == 0
assert [{"age": 4, "name": "Cleo"}] == list(
Database(db_path).query("select * from dogs")
)
def test_insert_invalid_json_error(tmpdir):
db_path = str(tmpdir / "dogs.db")
result = CliRunner().invoke(
cli.cli,
["insert", db_path, "dogs", "-"],
input="name,age\nCleo,4",
)
assert result.exit_code == 1
assert result.output == (
"Error: Invalid JSON - use --csv for CSV or --tsv for TSV files\n\n"
"JSON error: Expecting value: line 1 column 1 (char 0)\n"
)
def test_insert_json_flatten(tmpdir):
db_path = str(tmpdir / "flat.db")
result = CliRunner().invoke(
cli.cli,
["insert", db_path, "items", "-", "--flatten"],
input=json.dumps({"nested": {"data": 4}}),
)
assert result.exit_code == 0
assert list(Database(db_path).query("select * from items")) == [{"nested_data": 4}]
def test_insert_json_flatten_nl(tmpdir):
db_path = str(tmpdir / "flat.db")
result = CliRunner().invoke(
cli.cli,
["insert", db_path, "items", "-", "--flatten", "--nl"],
input="\n".join(
json.dumps(item)
for item in [{"nested": {"data": 4}}, {"nested": {"other": 3}}]
),
)
assert result.exit_code == 0
assert list(Database(db_path).query("select * from items")) == [
{"nested_data": 4, "nested_other": None},
{"nested_data": None, "nested_other": 3},
]
@pytest.mark.parametrize(
"args,expected_pks",
(
(["--pk", "id"], ["id"]),
(["--pk", "id", "--pk", "name"], ["id", "name"]),
),
)
def test_insert_with_primary_keys(db_path, tmpdir, args, expected_pks):
json_path = str(tmpdir / "dog.json")
with open(json_path, "w") as fp:
fp.write(json.dumps({"id": 1, "name": "Cleo", "age": 4}))
result = CliRunner().invoke(cli.cli, ["insert", db_path, "dogs", json_path] + args)
assert result.exit_code == 0
assert [{"id": 1, "age": 4, "name": "Cleo"}] == list(
Database(db_path).query("select * from dogs")
)
db = Database(db_path)
assert db["dogs"].pks == expected_pks
def test_insert_multiple_with_primary_key(db_path, tmpdir):
json_path = str(tmpdir / "dogs.json")
dogs = [{"id": i, "name": "Cleo {}".format(i), "age": i + 3} for i in range(1, 21)]
with open(json_path, "w") as fp:
fp.write(json.dumps(dogs))
result = CliRunner().invoke(
cli.cli, ["insert", db_path, "dogs", json_path, "--pk", "id"]
)
assert result.exit_code == 0
db = Database(db_path)
assert dogs == list(db.query("select * from dogs order by id"))
assert ["id"] == db["dogs"].pks
def test_insert_multiple_with_compound_primary_key(db_path, tmpdir):
json_path = str(tmpdir / "dogs.json")
dogs = [
{"breed": "mixed", "id": i, "name": "Cleo {}".format(i), "age": i + 3}
for i in range(1, 21)
]
with open(json_path, "w") as fp:
fp.write(json.dumps(dogs))
result = CliRunner().invoke(
cli.cli, ["insert", db_path, "dogs", json_path, "--pk", "id", "--pk", "breed"]
)
assert result.exit_code == 0
db = Database(db_path)
assert dogs == list(db.query("select * from dogs order by breed, id"))
assert {"breed", "id"} == set(db["dogs"].pks)
assert (
"CREATE TABLE [dogs] (\n"
" [breed] TEXT,\n"
" [id] INTEGER,\n"
" [name] TEXT,\n"
" [age] INTEGER,\n"
" PRIMARY KEY ([id], [breed])\n"
")"
) == db["dogs"].schema
def test_insert_not_null_default(db_path, tmpdir):
json_path = str(tmpdir / "dogs.json")
dogs = [
{"id": i, "name": "Cleo {}".format(i), "age": i + 3, "score": 10}
for i in range(1, 21)
]
with open(json_path, "w") as fp:
fp.write(json.dumps(dogs))
result = CliRunner().invoke(
cli.cli,
["insert", db_path, "dogs", json_path, "--pk", "id"]
+ ["--not-null", "name", "--not-null", "age"]
+ ["--default", "score", "5", "--default", "age", "1"],
)
assert result.exit_code == 0
db = Database(db_path)
assert (
"CREATE TABLE [dogs] (\n"
" [id] INTEGER PRIMARY KEY,\n"
" [name] TEXT NOT NULL,\n"
" [age] INTEGER NOT NULL DEFAULT '1',\n"
" [score] INTEGER DEFAULT '5'\n)"
) == db["dogs"].schema
def test_insert_binary_base64(db_path):
result = CliRunner().invoke(
cli.cli,
["insert", db_path, "files", "-"],
input=r'{"content": {"$base64": true, "encoded": "aGVsbG8="}}',
)
assert result.exit_code == 0, result.output
db = Database(db_path)
actual = list(db.query("select content from files"))
assert actual == [{"content": b"hello"}]
def test_insert_newline_delimited(db_path):
result = CliRunner().invoke(
cli.cli,
["insert", db_path, "from_json_nl", "-", "--nl"],
input='{"foo": "bar", "n": 1}\n\n{"foo": "baz", "n": 2}',
)
assert result.exit_code == 0, result.output
db = Database(db_path)
assert [
{"foo": "bar", "n": 1},
{"foo": "baz", "n": 2},
] == list(db.query("select foo, n from from_json_nl"))
def test_insert_ignore(db_path, tmpdir):
db = Database(db_path)
db["dogs"].insert({"id": 1, "name": "Cleo"}, pk="id")
json_path = str(tmpdir / "dogs.json")
with open(json_path, "w") as fp:
fp.write(json.dumps([{"id": 1, "name": "Bailey"}]))
# Should raise error without --ignore
result = CliRunner().invoke(
cli.cli, ["insert", db_path, "dogs", json_path, "--pk", "id"]
)
assert result.exit_code != 0, result.output
# If we use --ignore it should run OK
result = CliRunner().invoke(
cli.cli, ["insert", db_path, "dogs", json_path, "--pk", "id", "--ignore"]
)
assert result.exit_code == 0, result.output
# ... but it should actually have no effect
assert [{"id": 1, "name": "Cleo"}] == list(db.query("select * from dogs"))
@pytest.mark.parametrize(
"content,options",
[
("foo\tbar\tbaz\n1\t2\tcat,dog", ["--tsv"]),
('foo,bar,baz\n1,2,"cat,dog"', ["--csv"]),
('foo;bar;baz\n1;2;"cat,dog"', ["--csv", "--delimiter", ";"]),
# --delimiter implies --csv:
('foo;bar;baz\n1;2;"cat,dog"', ["--delimiter", ";"]),
("foo,bar,baz\n1,2,|cat,dog|", ["--csv", "--quotechar", "|"]),
("foo,bar,baz\n1,2,|cat,dog|", ["--quotechar", "|"]),
],
)
def test_insert_csv_tsv(content, options, db_path, tmpdir):
db = Database(db_path)
file_path = str(tmpdir / "insert.csv-tsv")
with open(file_path, "w") as fp:
fp.write(content)
result = CliRunner().invoke(
cli.cli,
["insert", db_path, "data", file_path] + options,
catch_exceptions=False,
)
assert result.exit_code == 0
assert [{"foo": "1", "bar": "2", "baz": "cat,dog"}] == list(db["data"].rows)
@pytest.mark.parametrize("empty_null", (True, False))
def test_insert_csv_empty_null(db_path, empty_null):
options = ["--csv"]
if empty_null:
options.append("--empty-null")
result = CliRunner().invoke(
cli.cli,
["insert", db_path, "data", "-"] + options,
catch_exceptions=False,
input="foo,bar,baz\n1,,cat,dog",
)
assert result.exit_code == 0
db = Database(db_path)
assert [r for r in db["data"].rows] == [
{"foo": "1", "bar": None if empty_null else "", "baz": "cat"}
]
@pytest.mark.parametrize(
"input,args",
(
(
json.dumps(
[{"name": "One"}, {"name": "Two"}, {"name": "Three"}, {"name": "Four"}]
),
[],
),
("name\nOne\nTwo\nThree\nFour\n", ["--csv"]),
),
)
def test_insert_stop_after(tmpdir, input, args):
db_path = str(tmpdir / "data.db")
result = CliRunner().invoke(
cli.cli,
["insert", db_path, "rows", "-", "--stop-after", "2"] + args,
input=input,
)
assert result.exit_code == 0
assert [{"name": "One"}, {"name": "Two"}] == list(
Database(db_path).query("select * from rows")
)
@pytest.mark.parametrize(
"options",
(
["--tsv", "--nl"],
["--tsv", "--csv"],
["--csv", "--nl"],
["--csv", "--nl", "--tsv"],
),
)
def test_only_allow_one_of_nl_tsv_csv(options, db_path, tmpdir):
file_path = str(tmpdir / "insert.csv-tsv")
with open(file_path, "w") as fp:
fp.write("foo")
result = CliRunner().invoke(
cli.cli, ["insert", db_path, "data", file_path] + options
)
assert result.exit_code != 0
assert "Error: Use just one of --nl, --csv or --tsv" == result.output.strip()
def test_insert_replace(db_path, tmpdir):
test_insert_multiple_with_primary_key(db_path, tmpdir)
json_path = str(tmpdir / "insert-replace.json")
db = Database(db_path)
assert db["dogs"].count == 20
insert_replace_dogs = [
{"id": 1, "name": "Insert replaced 1", "age": 4},
{"id": 2, "name": "Insert replaced 2", "age": 4},
{"id": 21, "name": "Fresh insert 21", "age": 6},
]
with open(json_path, "w") as fp:
fp.write(json.dumps(insert_replace_dogs))
result = CliRunner().invoke(
cli.cli, ["insert", db_path, "dogs", json_path, "--pk", "id", "--replace"]
)
assert result.exit_code == 0, result.output
assert db["dogs"].count == 21
assert (
list(db.query("select * from dogs where id in (1, 2, 21) order by id"))
== insert_replace_dogs
)
def test_insert_truncate(db_path):
result = CliRunner().invoke(
cli.cli,
["insert", db_path, "from_json_nl", "-", "--nl", "--batch-size=1"],
input='{"foo": "bar", "n": 1}\n{"foo": "baz", "n": 2}',
)
assert result.exit_code == 0, result.output
db = Database(db_path)
assert [
{"foo": "bar", "n": 1},
{"foo": "baz", "n": 2},
] == list(db.query("select foo, n from from_json_nl"))
# Truncate and insert new rows
result = CliRunner().invoke(
cli.cli,
[
"insert",
db_path,
"from_json_nl",
"-",
"--nl",
"--truncate",
"--batch-size=1",
],
input='{"foo": "bam", "n": 3}\n{"foo": "bat", "n": 4}',
)
assert result.exit_code == 0, result.output
assert [
{"foo": "bam", "n": 3},
{"foo": "bat", "n": 4},
] == list(db.query("select foo, n from from_json_nl"))
def test_insert_alter(db_path, tmpdir):
result = CliRunner().invoke(
cli.cli,
["insert", db_path, "from_json_nl", "-", "--nl"],
input='{"foo": "bar", "n": 1}\n{"foo": "baz", "n": 2}',
)
assert result.exit_code == 0, result.output
# Should get an error with incorrectly shaped additional data
result = CliRunner().invoke(
cli.cli,
["insert", db_path, "from_json_nl", "-", "--nl"],
input='{"foo": "bar", "baz": 5}',
)
assert result.exit_code != 0, result.output
# If we run it again with --alter it should work correctly
result = CliRunner().invoke(
cli.cli,
["insert", db_path, "from_json_nl", "-", "--nl", "--alter"],
input='{"foo": "bar", "baz": 5}',
)
assert result.exit_code == 0, result.output
# Soundness check the database itself
db = Database(db_path)
assert {"foo": str, "n": int, "baz": int} == db["from_json_nl"].columns_dict
assert [
{"foo": "bar", "n": 1, "baz": None},
{"foo": "baz", "n": 2, "baz": None},
{"foo": "bar", "baz": 5, "n": None},
] == list(db.query("select foo, n, baz from from_json_nl"))
def test_insert_analyze(db_path):
db = Database(db_path)
db["rows"].insert({"foo": "x", "n": 3})
db["rows"].create_index(["n"])
assert "sqlite_stat1" not in db.table_names()
result = CliRunner().invoke(
cli.cli,
["insert", db_path, "rows", "-", "--nl", "--analyze"],
input='{"foo": "bar", "n": 1}\n{"foo": "baz", "n": 2}',
)
assert result.exit_code == 0, result.output
assert "sqlite_stat1" in db.table_names()
def test_insert_lines(db_path):
result = CliRunner().invoke(
cli.cli,
["insert", db_path, "from_lines", "-", "--lines"],
input='First line\nSecond line\n{"foo": "baz"}',
)
assert result.exit_code == 0, result.output
db = Database(db_path)
assert [
{"line": "First line"},
{"line": "Second line"},
{"line": '{"foo": "baz"}'},
] == list(db.query("select line from from_lines"))
def test_insert_text(db_path):
result = CliRunner().invoke(
cli.cli,
["insert", db_path, "from_text", "-", "--text"],
input='First line\nSecond line\n{"foo": "baz"}',
)
assert result.exit_code == 0, result.output
db = Database(db_path)
assert [{"text": 'First line\nSecond line\n{"foo": "baz"}'}] == list(
db.query("select text from from_text")
)
@pytest.mark.parametrize(
"options,input",
(
([], '[{"id": "1", "name": "Bob"}, {"id": "2", "name": "Cat"}]'),
(["--csv"], "id,name\n1,Bob\n2,Cat"),
(["--nl"], '{"id": "1", "name": "Bob"}\n{"id": "2", "name": "Cat"}'),
),
)
def test_insert_convert_json_csv_jsonnl(db_path, options, input):
result = CliRunner().invoke(
cli.cli,
["insert", db_path, "rows", "-", "--convert", '{**row, **{"extra": 1}}']
+ options,
input=input,
)
assert result.exit_code == 0, result.output
db = Database(db_path)
rows = list(db.query("select id, name, extra from rows"))
assert rows == [
{"id": "1", "name": "Bob", "extra": 1},
{"id": "2", "name": "Cat", "extra": 1},
]
def test_insert_convert_text(db_path):
result = CliRunner().invoke(
cli.cli,
[
"insert",
db_path,
"text",
"-",
"--text",
"--convert",
'{"text": text.upper()}',
],
input="This is text\nwill be upper now",
)
assert result.exit_code == 0, result.output
db = Database(db_path)
rows = list(db.query("select [text] from [text]"))
assert rows == [{"text": "THIS IS TEXT\nWILL BE UPPER NOW"}]
def test_insert_convert_text_returning_iterator(db_path):
result = CliRunner().invoke(
cli.cli,
[
"insert",
db_path,
"text",
"-",
"--text",
"--convert",
'({"word": w} for w in text.split())',
],
input="A bunch of words",
)
assert result.exit_code == 0, result.output
db = Database(db_path)
rows = list(db.query("select [word] from [text]"))
assert rows == [{"word": "A"}, {"word": "bunch"}, {"word": "of"}, {"word": "words"}]
def test_insert_convert_lines(db_path):
result = CliRunner().invoke(
cli.cli,
[
"insert",
db_path,
"all",
"-",
"--lines",
"--convert",
'{"line": line.upper()}',
],
input="This is text\nwill be upper now",
)
assert result.exit_code == 0, result.output
db = Database(db_path)
rows = list(db.query("select [line] from [all]"))
assert rows == [{"line": "THIS IS TEXT"}, {"line": "WILL BE UPPER NOW"}]
def test_insert_convert_row_modifying_in_place(db_path):
result = CliRunner().invoke(
cli.cli,
[
"insert",
db_path,
"rows",
"-",
"--convert",
'row["is_chicken"] = True',
],
input='{"name": "Azi"}',
)
assert result.exit_code == 0, result.output
db = Database(db_path)
rows = list(db.query("select name, is_chicken from rows"))
assert rows == [{"name": "Azi", "is_chicken": 1}]
@pytest.mark.parametrize(
"options,expected_error",
(
(
["--text", "--convert", "1"],
"Error: --convert must return dict or iterator\n",
),
(["--convert", "1"], "Error: Rows must all be dictionaries, got: 1\n"),
),
)
def test_insert_convert_error_messages(db_path, options, expected_error):
result = CliRunner().invoke(
cli.cli,
[
"insert",
db_path,
"rows",
"-",
]
+ options,
input='{"name": "Azi"}',
)
assert result.exit_code == 1
assert result.output == expected_error
def test_insert_streaming_batch_size_1(db_path):
# https://github.com/simonw/sqlite-utils/issues/364
# Streaming with --batch-size 1 should commit on each record
# Can't use CliRunner().invoke() here because we need to
# run assertions in between writing to process stdin
proc = subprocess.Popen(
[
sys.executable,
"-m",
"sqlite_utils",
"insert",
db_path,
"rows",
"-",
"--nl",
"--batch-size",
"1",
],
stdin=subprocess.PIPE,
stdout=sys.stdout,
)
proc.stdin.write(b'{"name": "Azi"}\n')
proc.stdin.flush()
def try_until(expected):
tries = 0
while True:
rows = list(Database(db_path)["rows"].rows)
if rows == expected:
return
tries += 1
if tries > 10:
assert False, "Expected {}, got {}".format(expected, rows)
time.sleep(tries * 0.1)
try_until([{"name": "Azi"}])
proc.stdin.write(b'{"name": "Suna"}\n')
proc.stdin.flush()
try_until([{"name": "Azi"}, {"name": "Suna"}])
proc.stdin.close()
proc.wait()
assert proc.returncode == 0
</document_content>
</document>
<document index="21">
<source>./tests/test_cli_memory.py</source>
<document_content>
import click
import json
import pytest
from click.testing import CliRunner
from sqlite_utils import Database, cli
def test_memory_basic():
result = CliRunner().invoke(cli.cli, ["memory", "select 1 + 1"])
assert result.exit_code == 0
assert result.output.strip() == '[{"1 + 1": 2}]'
@pytest.mark.parametrize("sql_from", ("test", "t", "t1"))
@pytest.mark.parametrize("use_stdin", (True, False))
def test_memory_csv(tmpdir, sql_from, use_stdin):
content = "id,name\n1,Cleo\n2,Bants"
input = None
if use_stdin:
input = content
csv_path = "-"
if sql_from == "test":
sql_from = "stdin"
else:
csv_path = str(tmpdir / "test.csv")
with open(csv_path, "w") as fp:
fp.write(content)
result = CliRunner().invoke(
cli.cli,
["memory", csv_path, "select * from {}".format(sql_from), "--nl"],
input=input,
)
assert result.exit_code == 0
assert (
result.output.strip() == '{"id": 1, "name": "Cleo"}\n{"id": 2, "name": "Bants"}'
)
@pytest.mark.parametrize("use_stdin", (True, False))
def test_memory_tsv(tmpdir, use_stdin):
data = "id\tname\n1\tCleo\n2\tBants"
if use_stdin:
input = data
path = "stdin:tsv"
sql_from = "stdin"
else:
input = None
path = str(tmpdir / "chickens.tsv")
with open(path, "w") as fp:
fp.write(data)
path = path + ":tsv"
sql_from = "chickens"
result = CliRunner().invoke(
cli.cli,
["memory", path, "select * from {}".format(sql_from)],
input=input,
)
assert result.exit_code == 0, result.output
assert json.loads(result.output.strip()) == [
{"id": 1, "name": "Cleo"},
{"id": 2, "name": "Bants"},
]
@pytest.mark.parametrize("use_stdin", (True, False))
def test_memory_json(tmpdir, use_stdin):
data = '[{"name": "Bants"}, {"name": "Dori", "age": 1, "nested": {"nest": 1}}]'
if use_stdin:
input = data
path = "stdin:json"
sql_from = "stdin"
else:
input = None
path = str(tmpdir / "chickens.json")
with open(path, "w") as fp:
fp.write(data)
path = path + ":json"
sql_from = "chickens"
result = CliRunner().invoke(
cli.cli,
["memory", path, "select * from {}".format(sql_from)],
input=input,
)
assert result.exit_code == 0, result.output
assert json.loads(result.output.strip()) == [
{"name": "Bants", "age": None, "nested": None},
{"name": "Dori", "age": 1, "nested": '{"nest": 1}'},
]
@pytest.mark.parametrize("use_stdin", (True, False))
def test_memory_json_nl(tmpdir, use_stdin):
data = '{"name": "Bants"}\n\n{"name": "Dori"}'
if use_stdin:
input = data
path = "stdin:nl"
sql_from = "stdin"
else:
input = None
path = str(tmpdir / "chickens.json")
with open(path, "w") as fp:
fp.write(data)
path = path + ":nl"
sql_from = "chickens"
result = CliRunner().invoke(
cli.cli,
["memory", path, "select * from {}".format(sql_from)],
input=input,
)
assert result.exit_code == 0, result.output
assert json.loads(result.output.strip()) == [
{"name": "Bants"},
{"name": "Dori"},
]
@pytest.mark.parametrize("use_stdin", (True, False))
def test_memory_csv_encoding(tmpdir, use_stdin):
latin1_csv = (
b"date,name,latitude,longitude\n" b"2020-03-04,S\xe3o Paulo,-23.561,-46.645\n"
)
input = None
if use_stdin:
input = latin1_csv
csv_path = "-"
sql_from = "stdin"
else:
csv_path = str(tmpdir / "test.csv")
with open(csv_path, "wb") as fp:
fp.write(latin1_csv)
sql_from = "test"
# Without --encoding should error:
assert (
CliRunner()
.invoke(
cli.cli,
["memory", csv_path, "select * from {}".format(sql_from), "--nl"],
input=input,
)
.exit_code
== 1
)
# With --encoding should work:
result = CliRunner().invoke(
cli.cli,
["memory", "-", "select * from stdin", "--encoding", "latin-1", "--nl"],
input=latin1_csv,
)
assert result.exit_code == 0, result.output
assert json.loads(result.output.strip()) == {
"date": "2020-03-04",
"name": "São Paulo",
"latitude": -23.561,
"longitude": -46.645,
}
@pytest.mark.parametrize("extra_args", ([], ["select 1"]))
def test_memory_dump(extra_args):
result = CliRunner().invoke(
cli.cli,
["memory", "-"] + extra_args + ["--dump"],
input="id,name\n1,Cleo\n2,Bants",
)
assert result.exit_code == 0
expected = (
"BEGIN TRANSACTION;\n"
'CREATE TABLE IF NOT EXISTS "stdin" (\n'
" [id] INTEGER,\n"
" [name] TEXT\n"
");\n"
"INSERT INTO \"stdin\" VALUES(1,'Cleo');\n"
"INSERT INTO \"stdin\" VALUES(2,'Bants');\n"
"CREATE VIEW t1 AS select * from [stdin];\n"
"CREATE VIEW t AS select * from [stdin];\n"
"COMMIT;"
)
# When sqlite-dump is used the output won't have IF NOT EXISTS
expected_alternative = expected.replace("IF NOT EXISTS ", "")
assert result.output.strip() in (expected, expected_alternative)
@pytest.mark.parametrize("extra_args", ([], ["select 1"]))
def test_memory_schema(extra_args):
result = CliRunner().invoke(
cli.cli,
["memory", "-"] + extra_args + ["--schema"],
input="id,name\n1,Cleo\n2,Bants",
)
assert result.exit_code == 0
assert result.output.strip() == (
'CREATE TABLE "stdin" (\n'
" [id] INTEGER,\n"
" [name] TEXT\n"
");\n"
"CREATE VIEW t1 AS select * from [stdin];\n"
"CREATE VIEW t AS select * from [stdin];"
)
@pytest.mark.parametrize("extra_args", ([], ["select 1"]))
def test_memory_save(tmpdir, extra_args):
save_to = str(tmpdir / "save.db")
result = CliRunner().invoke(
cli.cli,
["memory", "-"] + extra_args + ["--save", save_to],
input="id,name\n1,Cleo\n2,Bants",
)
assert result.exit_code == 0
db = Database(save_to)
assert list(db["stdin"].rows) == [
{"id": 1, "name": "Cleo"},
{"id": 2, "name": "Bants"},
]
@pytest.mark.parametrize("option", ("-n", "--no-detect-types"))
def test_memory_no_detect_types(option):
result = CliRunner().invoke(
cli.cli,
["memory", "-", "select * from stdin"] + [option],
input="id,name,weight\n1,Cleo,45.5\n2,Bants,3.5",
)
assert result.exit_code == 0, result.output
assert json.loads(result.output.strip()) == [
{"id": "1", "name": "Cleo", "weight": "45.5"},
{"id": "2", "name": "Bants", "weight": "3.5"},
]
def test_memory_flatten():
result = CliRunner().invoke(
cli.cli,
["memory", "-", "select * from stdin", "--flatten"],
input=json.dumps(
{
"httpRequest": {
"latency": "0.112114537s",
"requestMethod": "GET",
},
"insertId": "6111722f000b5b4c4d4071e2",
}
),
)
assert result.exit_code == 0, result.output
assert json.loads(result.output.strip()) == [
{
"httpRequest_latency": "0.112114537s",
"httpRequest_requestMethod": "GET",
"insertId": "6111722f000b5b4c4d4071e2",
}
]
def test_memory_analyze():
result = CliRunner().invoke(
cli.cli,
["memory", "-", "--analyze"],
input="id,name\n1,Cleo\n2,Bants",
)
assert result.exit_code == 0
assert result.output == (
"stdin.id: (1/2)\n\n"
" Total rows: 2\n"
" Null rows: 0\n"
" Blank rows: 0\n\n"
" Distinct values: 2\n\n"
"stdin.name: (2/2)\n\n"
" Total rows: 2\n"
" Null rows: 0\n"
" Blank rows: 0\n\n"
" Distinct values: 2\n\n"
)
def test_memory_two_files_with_same_stem(tmpdir):
(tmpdir / "one").mkdir()
(tmpdir / "two").mkdir()
one = tmpdir / "one" / "data.csv"
two = tmpdir / "two" / "data.csv"
one.write_text("id,name\n1,Cleo\n2,Bants", encoding="utf-8")
two.write_text("id,name\n3,Blue\n4,Lila", encoding="utf-8")
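# Both files share the stem "data", so the second should be exposed as table "data_2"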
result = CliRunner().invoke(cli.cli, ["memory", str(one), str(two), "", "--schema"])
assert result.exit_code == 0
assert result.output == (
'CREATE TABLE "data" (\n'
" [id] INTEGER,\n"
" [name] TEXT\n"
");\n"
"CREATE VIEW t1 AS select * from [data];\n"
"CREATE VIEW t AS select * from [data];\n"
'CREATE TABLE "data_2" (\n'
" [id] INTEGER,\n"
" [name] TEXT\n"
");\n"
"CREATE VIEW t2 AS select * from [data_2];\n"
)
def test_memory_functions():
result = CliRunner().invoke(
cli.cli,
["memory", "select hello()", "--functions", "hello = lambda: 'Hello'"],
)
assert result.exit_code == 0
assert result.output.strip() == '[{"hello()": "Hello"}]'
def test_memory_return_db(tmpdir):
# https://github.com/simonw/sqlite-utils/issues/643
from sqlite_utils.cli import cli
path = str(tmpdir / "dogs.csv")
open(path, "w").write("id,name\n1,Cleo")
with click.Context(cli) as ctx:
db = ctx.invoke(cli.commands["memory"], paths=(path,), return_db=True)
assert db.table_names() == ["dogs"]
</document_content>
</document>
<document index="22">
<source>./tests/test_column_affinity.py</source>
<document_content>
import pytest
from sqlite_utils.utils import column_affinity
EXAMPLES = [
# Examples from https://www.sqlite.org/datatype3.html#affinity_name_examples
("INT", int),
("INTEGER", int),
("TINYINT", int),
("SMALLINT", int),
("MEDIUMINT", int),
("BIGINT", int),
("UNSIGNED BIG INT", int),
("INT2", int),
("INT8", int),
("CHARACTER(20)", str),
("VARCHAR(255)", str),
("VARYING CHARACTER(255)", str),
("NCHAR(55)", str),
("NATIVE CHARACTER(70)", str),
("NVARCHAR(100)", str),
("TEXT", str),
("CLOB", str),
("BLOB", bytes),
("REAL", float),
("DOUBLE", float),
("DOUBLE PRECISION", float),
("FLOAT", float),
# Numeric, treated as float:
("NUMERIC", float),
("DECIMAL(10,5)", float),
("BOOLEAN", float),
("DATE", float),
("DATETIME", float),
]
@pytest.mark.parametrize("column_def,expected_type", EXAMPLES)
def test_column_affinity(column_def, expected_type):
assert expected_type is column_affinity(column_def)
@pytest.mark.parametrize("column_def,expected_type", EXAMPLES)
def test_columns_dict(fresh_db, column_def, expected_type):
fresh_db.execute("create table foo (col {})".format(column_def))
assert {"col": expected_type} == fresh_db["foo"].columns_dict
</document_content>
</document>
<document index="23">
<source>./tests/test_constructor.py</source>
<document_content>
from sqlite_utils import Database
from sqlite_utils.utils import sqlite3
import pytest
def test_recursive_triggers():
db = Database(memory=True)
assert db.execute("PRAGMA recursive_triggers").fetchone()[0]
def test_recursive_triggers_off():
db = Database(memory=True, recursive_triggers=False)
assert not db.execute("PRAGMA recursive_triggers").fetchone()[0]
def test_memory_name():
db1 = Database(memory_name="shared")
db2 = Database(memory_name="shared")
db1["dogs"].insert({"name": "Cleo"})
assert list(db2["dogs"].rows) == [{"name": "Cleo"}]
def test_sqlite_version():
db = Database(memory=True)
version = db.sqlite_version
assert isinstance(version, tuple)
as_string = ".".join(map(str, version))
actual = next(db.query("select sqlite_version() as v"))["v"]
assert actual == as_string
@pytest.mark.parametrize("memory", [True, False])
def test_database_close(tmpdir, memory):
if memory:
db = Database(memory=True)
else:
db = Database(str(tmpdir / "test.db"))
assert db.execute("select 1 + 1").fetchone()[0] == 2
db.close()
with pytest.raises(sqlite3.ProgrammingError):
db.execute("select 1 + 1")
</document_content>
</document>
<document index="24">
<source>./tests/test_conversions.py</source>
<document_content>
def test_insert_conversion(fresh_db):
table = fresh_db["table"]
table.insert({"foo": "bar"}, conversions={"foo": "upper(?)"})
assert [{"foo": "BAR"}] == list(table.rows)
def test_insert_all_conversion(fresh_db):
table = fresh_db["table"]
table.insert_all([{"foo": "bar"}], conversions={"foo": "upper(?)"})
assert [{"foo": "BAR"}] == list(table.rows)
def test_upsert_conversion(fresh_db):
table = fresh_db["table"]
table.upsert({"id": 1, "foo": "bar"}, pk="id", conversions={"foo": "upper(?)"})
assert [{"id": 1, "foo": "BAR"}] == list(table.rows)
table.upsert(
{"id": 1, "bar": "baz"}, pk="id", conversions={"bar": "upper(?)"}, alter=True
)
assert [{"id": 1, "foo": "BAR", "bar": "BAZ"}] == list(table.rows)
def test_upsert_all_conversion(fresh_db):
table = fresh_db["table"]
table.upsert_all(
[{"id": 1, "foo": "bar"}], pk="id", conversions={"foo": "upper(?)"}
)
assert [{"id": 1, "foo": "BAR"}] == list(table.rows)
def test_update_conversion(fresh_db):
table = fresh_db["table"]
table.insert({"id": 5, "foo": "bar"}, pk="id")
table.update(5, {"foo": "baz"}, conversions={"foo": "upper(?)"})
assert [{"id": 5, "foo": "BAZ"}] == list(table.rows)
def test_table_constructor_conversion(fresh_db):
table = fresh_db.table("table", conversions={"bar": "upper(?)"})
table.insert({"bar": "baz"})
assert [{"bar": "BAZ"}] == list(table.rows)
</document_content>
</document>
<document index="25">
<source>./tests/test_convert.py</source>
<document_content>
from sqlite_utils.db import BadMultiValues
import pytest
@pytest.mark.parametrize(
"columns,fn,expected",
(
(
"title",
lambda value: value.upper(),
{"title": "MIXED CASE", "abstract": "Abstract"},
),
(
["title", "abstract"],
lambda value: value.upper(),
{"title": "MIXED CASE", "abstract": "ABSTRACT"},
),
(
"title",
lambda value: {"upper": value.upper(), "lower": value.lower()},
{
"title": '{"upper": "MIXED CASE", "lower": "mixed case"}',
"abstract": "Abstract",
},
),
),
)
def test_convert(fresh_db, columns, fn, expected):
table = fresh_db["table"]
table.insert({"title": "Mixed Case", "abstract": "Abstract"})
table.convert(columns, fn)
assert list(table.rows) == [expected]
@pytest.mark.parametrize(
"where,where_args", (("id > 1", None), ("id > :id", {"id": 1}), ("id > ?", [1]))
)
def test_convert_where(fresh_db, where, where_args):
table = fresh_db["table"]
table.insert_all(
[
{"id": 1, "title": "One"},
{"id": 2, "title": "Two"},
],
pk="id",
)
table.convert(
"title", lambda value: value.upper(), where=where, where_args=where_args
)
assert list(table.rows) == [{"id": 1, "title": "One"}, {"id": 2, "title": "TWO"}]
def test_convert_skip_false(fresh_db):
table = fresh_db["table"]
table.insert_all([{"x": 0}, {"x": 1}])
assert table.get(1)["x"] == 0
assert table.get(2)["x"] == 1
table.convert("x", lambda x: x + 1, skip_false=False)
assert table.get(1)["x"] == 1
assert table.get(2)["x"] == 2
@pytest.mark.parametrize(
"drop,expected",
(
(False, {"title": "Mixed Case", "other": "MIXED CASE"}),
(True, {"other": "MIXED CASE"}),
),
)
def test_convert_output(fresh_db, drop, expected):
table = fresh_db["table"]
table.insert({"title": "Mixed Case"})
table.convert("title", lambda v: v.upper(), output="other", drop=drop)
assert list(table.rows) == [expected]
def test_convert_output_multiple_column_error(fresh_db):
table = fresh_db["table"]
with pytest.raises(AssertionError) as excinfo:
table.convert(["title", "other"], lambda v: v, output="out")
assert "output= can only be used with a single column" in str(excinfo.value)
@pytest.mark.parametrize(
"type,expected",
(
(int, {"other": 123}),
(float, {"other": 123.0}),
),
)
def test_convert_output_type(fresh_db, type, expected):
table = fresh_db["table"]
table.insert({"number": "123"})
table.convert("number", lambda v: v, output="other", output_type=type, drop=True)
assert list(table.rows) == [expected]
def test_convert_multi(fresh_db):
table = fresh_db["table"]
table.insert({"title": "Mixed Case"})
table.convert(
"title",
lambda v: {
"upper": v.upper(),
"lower": v.lower(),
"both": {
"upper": v.upper(),
"lower": v.lower(),
},
},
multi=True,
)
assert list(table.rows) == [
{
"title": "Mixed Case",
"upper": "MIXED CASE",
"lower": "mixed case",
"both": '{"upper": "MIXED CASE", "lower": "mixed case"}',
}
]
def test_convert_multi_where(fresh_db):
table = fresh_db["table"]
table.insert_all(
[
{"id": 1, "title": "One"},
{"id": 2, "title": "Two"},
],
pk="id",
)
table.convert(
"title",
lambda v: {"upper": v.upper(), "lower": v.lower()},
multi=True,
where="id > ?",
where_args=[1],
)
assert list(table.rows) == [
{"id": 1, "lower": None, "title": "One", "upper": None},
{"id": 2, "lower": "two", "title": "Two", "upper": "TWO"},
]
def test_convert_multi_exception(fresh_db):
table = fresh_db["table"]
table.insert({"title": "Mixed Case"})
with pytest.raises(BadMultiValues):
table.convert("title", lambda v: v.upper(), multi=True)
def test_convert_repeated(fresh_db):
table = fresh_db["table"]
col = "num"
table.insert({col: 1})
table.convert(col, lambda x: x * 2)
table.convert(col, lambda _x: 0)
assert table.get(1) == {col: 0}
</document_content>
</document>
<document index="26">
<source>./tests/test_create.py</source>
<document_content>
from sqlite_utils.db import (
Index,
Database,
DescIndex,
AlterError,
NoObviousTable,
OperationalError,
ForeignKey,
Table,
View,
)
from sqlite_utils.utils import hash_record, sqlite3
import collections
import datetime
import decimal
import json
import pathlib
import pytest
import uuid
try:
import pandas as pd # type: ignore
except ImportError:
pd = None # type: ignore
def test_create_table(fresh_db):
assert [] == fresh_db.table_names()
table = fresh_db.create_table(
"test_table",
{
"text_col": str,
"float_col": float,
"int_col": int,
"bool_col": bool,
"bytes_col": bytes,
"datetime_col": datetime.datetime,
},
)
assert ["test_table"] == fresh_db.table_names()
assert [
{"name": "text_col", "type": "TEXT"},
{"name": "float_col", "type": "FLOAT"},
{"name": "int_col", "type": "INTEGER"},
{"name": "bool_col", "type": "INTEGER"},
{"name": "bytes_col", "type": "BLOB"},
{"name": "datetime_col", "type": "TEXT"},
] == [{"name": col.name, "type": col.type} for col in table.columns]
assert (
"CREATE TABLE [test_table] (\n"
" [text_col] TEXT,\n"
" [float_col] FLOAT,\n"
" [int_col] INTEGER,\n"
" [bool_col] INTEGER,\n"
" [bytes_col] BLOB,\n"
" [datetime_col] TEXT\n"
")"
) == table.schema
def test_create_table_compound_primary_key(fresh_db):
table = fresh_db.create_table(
"test_table", {"id1": str, "id2": str, "value": int}, pk=("id1", "id2")
)
assert (
"CREATE TABLE [test_table] (\n"
" [id1] TEXT,\n"
" [id2] TEXT,\n"
" [value] INTEGER,\n"
" PRIMARY KEY ([id1], [id2])\n"
")"
) == table.schema
assert ["id1", "id2"] == table.pks
@pytest.mark.parametrize("pk", ("id", ["id"]))
def test_create_table_with_single_primary_key(fresh_db, pk):
fresh_db["foo"].insert({"id": 1}, pk=pk)
assert (
fresh_db["foo"].schema == "CREATE TABLE [foo] (\n [id] INTEGER PRIMARY KEY\n)"
)
def test_create_table_with_invalid_column_characters(fresh_db):
with pytest.raises(AssertionError):
fresh_db.create_table("players", {"name[foo]": str})
def test_create_table_with_defaults(fresh_db):
table = fresh_db.create_table(
"players",
{"name": str, "score": int},
defaults={"score": 1, "name": "bob''bob"},
)
assert ["players"] == fresh_db.table_names()
assert [{"name": "name", "type": "TEXT"}, {"name": "score", "type": "INTEGER"}] == [
{"name": col.name, "type": col.type} for col in table.columns
]
assert (
"CREATE TABLE [players] (\n [name] TEXT DEFAULT 'bob''''bob',\n [score] INTEGER DEFAULT 1\n)"
) == table.schema
def test_create_table_with_bad_not_null(fresh_db):
with pytest.raises(AssertionError):
fresh_db.create_table(
"players", {"name": str, "score": int}, not_null={"mouse"}
)
def test_create_table_with_not_null(fresh_db):
table = fresh_db.create_table(
"players",
{"name": str, "score": int},
not_null={"name", "score"},
defaults={"score": 3},
)
assert ["players"] == fresh_db.table_names()
assert [{"name": "name", "type": "TEXT"}, {"name": "score", "type": "INTEGER"}] == [
{"name": col.name, "type": col.type} for col in table.columns
]
assert (
"CREATE TABLE [players] (\n [name] TEXT NOT NULL,\n [score] INTEGER NOT NULL DEFAULT 3\n)"
) == table.schema
@pytest.mark.parametrize(
"example,expected_columns",
(
(
{"name": "Ravi", "age": 63},
[{"name": "name", "type": "TEXT"}, {"name": "age", "type": "INTEGER"}],
),
(
{"create": "Reserved word", "table": "Another"},
[{"name": "create", "type": "TEXT"}, {"name": "table", "type": "TEXT"}],
),
({"day": datetime.time(11, 0)}, [{"name": "day", "type": "TEXT"}]),
({"decimal": decimal.Decimal("1.2")}, [{"name": "decimal", "type": "FLOAT"}]),
(
{"memoryview": memoryview(b"hello")},
[{"name": "memoryview", "type": "BLOB"}],
),
({"uuid": uuid.uuid4()}, [{"name": "uuid", "type": "TEXT"}]),
({"foo[bar]": 1}, [{"name": "foo_bar_", "type": "INTEGER"}]),
(
{"timedelta": datetime.timedelta(hours=1)},
[{"name": "timedelta", "type": "TEXT"}],
),
),
)
def test_create_table_from_example(fresh_db, example, expected_columns):
people_table = fresh_db["people"]
assert people_table.last_rowid is None
assert people_table.last_pk is None
people_table.insert(example)
assert people_table.last_rowid == 1
assert people_table.last_pk == 1
assert ["people"] == fresh_db.table_names()
assert expected_columns == [
{"name": col.name, "type": col.type} for col in fresh_db["people"].columns
]
def test_create_table_from_example_with_compound_primary_keys(fresh_db):
record = {"name": "Zhang", "group": "staff", "employee_id": 2}
table = fresh_db["people"].insert(record, pk=("group", "employee_id"))
assert ["group", "employee_id"] == table.pks
assert record == table.get(("staff", 2))
@pytest.mark.parametrize(
"method_name", ("insert", "upsert", "insert_all", "upsert_all")
)
def test_create_table_with_custom_columns(fresh_db, method_name):
table = fresh_db["dogs"]
method = getattr(table, method_name)
record = {"id": 1, "name": "Cleo", "age": "5"}
if method_name.endswith("_all"):
record = [record]
method(record, pk="id", columns={"age": int, "weight": float})
assert ["dogs"] == fresh_db.table_names()
expected_columns = [
{"name": "id", "type": "INTEGER"},
{"name": "name", "type": "TEXT"},
{"name": "age", "type": "INTEGER"},
{"name": "weight", "type": "FLOAT"},
]
assert expected_columns == [
{"name": col.name, "type": col.type} for col in table.columns
]
assert [{"id": 1, "name": "Cleo", "age": 5, "weight": None}] == list(table.rows)
@pytest.mark.parametrize("use_table_factory", [True, False])
def test_create_table_column_order(fresh_db, use_table_factory):
row = collections.OrderedDict(
(
("zzz", "third"),
("abc", "first"),
("ccc", "second"),
("bbb", "second-to-last"),
("aaa", "last"),
)
)
column_order = ("abc", "ccc", "zzz")
if use_table_factory:
fresh_db.table("table", column_order=column_order).insert(row)
else:
fresh_db["table"].insert(row, column_order=column_order)
assert [
{"name": "abc", "type": "TEXT"},
{"name": "ccc", "type": "TEXT"},
{"name": "zzz", "type": "TEXT"},
{"name": "bbb", "type": "TEXT"},
{"name": "aaa", "type": "TEXT"},
] == [{"name": col.name, "type": col.type} for col in fresh_db["table"].columns]
@pytest.mark.parametrize(
"foreign_key_specification,expected_exception",
(
# You can specify triples, pairs, or a list of columns
((("one_id", "one", "id"), ("two_id", "two", "id")), False),
((("one_id", "one"), ("two_id", "two")), False),
(("one_id", "two_id"), False),
# You can also specify ForeignKey tuples:
(
(
ForeignKey("m2m", "one_id", "one", "id"),
ForeignKey("m2m", "two_id", "two", "id"),
),
False,
),
# If you specify a column that doesn't point to a table, you get an error:
(("one_id", "two_id", "three_id"), NoObviousTable),
# Tuples of the wrong length get an error:
((("one_id", "one", "id", "five"), ("two_id", "two", "id")), AssertionError),
# Likewise a bad column:
((("one_id", "one", "id2"),), AlterError),
# Or a list of dicts
(({"one_id": "one"},), AssertionError),
),
)
@pytest.mark.parametrize("use_table_factory", [True, False])
def test_create_table_works_for_m2m_with_only_foreign_keys(
fresh_db, foreign_key_specification, expected_exception, use_table_factory
):
if use_table_factory:
fresh_db.table("one", pk="id").insert({"id": 1})
fresh_db.table("two", pk="id").insert({"id": 1})
else:
fresh_db["one"].insert({"id": 1}, pk="id")
fresh_db["two"].insert({"id": 1}, pk="id")
row = {"one_id": 1, "two_id": 1}
def do_it():
if use_table_factory:
fresh_db.table("m2m", foreign_keys=foreign_key_specification).insert(row)
else:
fresh_db["m2m"].insert(row, foreign_keys=foreign_key_specification)
if expected_exception:
with pytest.raises(expected_exception):
do_it()
return
else:
do_it()
assert [
{"name": "one_id", "type": "INTEGER"},
{"name": "two_id", "type": "INTEGER"},
] == [{"name": col.name, "type": col.type} for col in fresh_db["m2m"].columns]
assert sorted(
[
{"column": "one_id", "other_table": "one", "other_column": "id"},
{"column": "two_id", "other_table": "two", "other_column": "id"},
],
key=lambda s: repr(s),
) == sorted(
[
{
"column": fk.column,
"other_table": fk.other_table,
"other_column": fk.other_column,
}
for fk in fresh_db["m2m"].foreign_keys
],
key=lambda s: repr(s),
)
def test_self_referential_foreign_key(fresh_db):
assert [] == fresh_db.table_names()
table = fresh_db.create_table(
"test_table",
columns={
"id": int,
"ref": int,
},
pk="id",
foreign_keys=(("ref", "test_table", "id"),),
)
assert (
"CREATE TABLE [test_table] (\n"
" [id] INTEGER PRIMARY KEY,\n"
" [ref] INTEGER REFERENCES [test_table]([id])\n"
")"
) == table.schema
def test_create_error_if_invalid_foreign_keys(fresh_db):
with pytest.raises(AlterError):
fresh_db["one"].insert(
{"id": 1, "ref_id": 3},
pk="id",
foreign_keys=(("ref_id", "bad_table", "bad_column"),),
)
def test_create_error_if_invalid_self_referential_foreign_keys(fresh_db):
with pytest.raises(AlterError) as ex:
fresh_db["one"].insert(
{"id": 1, "ref_id": 3},
pk="id",
foreign_keys=(("ref_id", "one", "bad_column"),),
)
assert ex.value.args == ("No such column: one.bad_column",)
@pytest.mark.parametrize(
"col_name,col_type,not_null_default,expected_schema",
(
(
"nickname",
str,
None,
"CREATE TABLE [dogs] (\n [name] TEXT\n, [nickname] TEXT)",
),
(
"dob",
datetime.date,
None,
"CREATE TABLE [dogs] (\n [name] TEXT\n, [dob] TEXT)",
),
("age", int, None, "CREATE TABLE [dogs] (\n [name] TEXT\n, [age] INTEGER)"),
(
"weight",
float,
None,
"CREATE TABLE [dogs] (\n [name] TEXT\n, [weight] FLOAT)",
),
("text", "TEXT", None, "CREATE TABLE [dogs] (\n [name] TEXT\n, [text] TEXT)"),
(
"integer",
"INTEGER",
None,
"CREATE TABLE [dogs] (\n [name] TEXT\n, [integer] INTEGER)",
),
(
"float",
"FLOAT",
None,
"CREATE TABLE [dogs] (\n [name] TEXT\n, [float] FLOAT)",
),
("blob", "blob", None, "CREATE TABLE [dogs] (\n [name] TEXT\n, [blob] BLOB)"),
(
"default_str",
None,
None,
"CREATE TABLE [dogs] (\n [name] TEXT\n, [default_str] TEXT)",
),
(
"nickname",
str,
"",
"CREATE TABLE [dogs] (\n [name] TEXT\n, [nickname] TEXT NOT NULL DEFAULT '')",
),
(
"nickname",
str,
"dawg's dawg",
"CREATE TABLE [dogs] (\n [name] TEXT\n, [nickname] TEXT NOT NULL DEFAULT 'dawg''s dawg')",
),
),
)
def test_add_column(fresh_db, col_name, col_type, not_null_default, expected_schema):
fresh_db.create_table("dogs", {"name": str})
assert fresh_db["dogs"].schema == "CREATE TABLE [dogs] (\n [name] TEXT\n)"
fresh_db["dogs"].add_column(col_name, col_type, not_null_default=not_null_default)
assert fresh_db["dogs"].schema == expected_schema
def test_add_foreign_key(fresh_db):
fresh_db["authors"].insert_all(
[{"id": 1, "name": "Sally"}, {"id": 2, "name": "Asheesh"}], pk="id"
)
fresh_db["books"].insert_all(
[
{"title": "Hedgehogs of the world", "author_id": 1},
{"title": "How to train your wolf", "author_id": 2},
]
)
assert [] == fresh_db["books"].foreign_keys
t = fresh_db["books"].add_foreign_key("author_id", "authors", "id")
# Ensure it returned self:
assert isinstance(t, Table) and t.name == "books"
assert [
ForeignKey(
table="books", column="author_id", other_table="authors", other_column="id"
)
] == fresh_db["books"].foreign_keys
def test_add_foreign_key_if_column_contains_space(fresh_db):
fresh_db["authors"].insert_all([{"id": 1, "name": "Sally"}], pk="id")
fresh_db["books"].insert_all(
[
{"title": "Hedgehogs of the world", "author id": 1},
]
)
fresh_db["books"].add_foreign_key("author id", "authors", "id")
assert fresh_db["books"].foreign_keys == [
ForeignKey(
table="books", column="author id", other_table="authors", other_column="id"
)
]
def test_add_foreign_key_error_if_column_does_not_exist(fresh_db):
fresh_db["books"].insert(
{"id": 1, "title": "Hedgehogs of the world", "author_id": 1}
)
with pytest.raises(AlterError):
fresh_db["books"].add_foreign_key("author2_id", "books", "id")
def test_add_foreign_key_error_if_other_table_does_not_exist(fresh_db):
fresh_db["books"].insert({"title": "Hedgehogs of the world", "author_id": 1})
with pytest.raises(AlterError):
fresh_db["books"].add_foreign_key("author_id", "authors", "id")
def test_add_foreign_key_error_if_already_exists(fresh_db):
fresh_db["books"].insert({"title": "Hedgehogs of the world", "author_id": 1})
fresh_db["authors"].insert({"id": 1, "name": "Sally"}, pk="id")
fresh_db["books"].add_foreign_key("author_id", "authors", "id")
with pytest.raises(AlterError) as ex:
fresh_db["books"].add_foreign_key("author_id", "authors", "id")
assert "Foreign key already exists for author_id => authors.id" == ex.value.args[0]
def test_add_foreign_key_no_error_if_exists_and_ignore_true(fresh_db):
fresh_db["books"].insert({"title": "Hedgehogs of the world", "author_id": 1})
fresh_db["authors"].insert({"id": 1, "name": "Sally"}, pk="id")
fresh_db["books"].add_foreign_key("author_id", "authors", "id")
fresh_db["books"].add_foreign_key("author_id", "authors", "id", ignore=True)
def test_add_foreign_keys(fresh_db):
fresh_db["authors"].insert_all(
[{"id": 1, "name": "Sally"}, {"id": 2, "name": "Asheesh"}], pk="id"
)
fresh_db["categories"].insert_all([{"id": 1, "name": "Wildlife"}], pk="id")
fresh_db["books"].insert_all(
[{"title": "Hedgehogs of the world", "author_id": 1, "category_id": 1}]
)
assert [] == fresh_db["books"].foreign_keys
fresh_db.add_foreign_keys(
[
("books", "author_id", "authors", "id"),
("books", "category_id", "categories", "id"),
]
)
assert [
ForeignKey(
table="books", column="author_id", other_table="authors", other_column="id"
),
ForeignKey(
table="books",
column="category_id",
other_table="categories",
other_column="id",
),
] == sorted(fresh_db["books"].foreign_keys)
def test_add_column_foreign_key(fresh_db):
fresh_db.create_table("dogs", {"name": str})
fresh_db.create_table("breeds", {"name": str})
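    # breeds has no explicit primary key, so the new column is expected to reference its rowid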
fresh_db["dogs"].add_column("breed_id", fk="breeds")
assert fresh_db["dogs"].schema == (
'CREATE TABLE "dogs" (\n'
" [name] TEXT,\n"
" [breed_id] INTEGER REFERENCES [breeds]([rowid])\n"
")"
)
# And again with an explicit primary key column
fresh_db.create_table("subbreeds", {"name": str, "primkey": str}, pk="primkey")
fresh_db["dogs"].add_column("subbreed_id", fk="subbreeds")
assert fresh_db["dogs"].schema == (
'CREATE TABLE "dogs" (\n'
" [name] TEXT,\n"
" [breed_id] INTEGER REFERENCES [breeds]([rowid]),\n"
" [subbreed_id] TEXT REFERENCES [subbreeds]([primkey])\n"
")"
)
def test_add_foreign_key_guess_table(fresh_db):
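    # add_foreign_key() called with just a column name should guess the other
    # table from the "<table>_id" naming convention: breed_id -> breeds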
fresh_db.create_table("dogs", {"name": str})
fresh_db.create_table("breeds", {"name": str, "id": int}, pk="id")
fresh_db["dogs"].add_column("breed_id", int)
fresh_db["dogs"].add_foreign_key("breed_id")
assert fresh_db["dogs"].schema == (
'CREATE TABLE "dogs" (\n'
" [name] TEXT,\n"
" [breed_id] INTEGER REFERENCES [breeds]([id])\n"
")"
)
def test_index_foreign_keys(fresh_db):
test_add_foreign_key_guess_table(fresh_db)
assert [] == fresh_db["dogs"].indexes
fresh_db.index_foreign_keys()
assert [["breed_id"]] == [i.columns for i in fresh_db["dogs"].indexes]
# Calling it a second time should do nothing
fresh_db.index_foreign_keys()
assert [["breed_id"]] == [i.columns for i in fresh_db["dogs"].indexes]
def test_index_foreign_keys_if_index_name_is_already_used(fresh_db):
    # https://github.com/simonw/sqlite-utils/issues/335
test_add_foreign_key_guess_table(fresh_db)
# Add index with a name that will conflict with index_foreign_keys()
fresh_db["dogs"].create_index(["name"], index_name="idx_dogs_breed_id")
fresh_db.index_foreign_keys()
assert {(idx.name, tuple(idx.columns)) for idx in fresh_db["dogs"].indexes} == {
("idx_dogs_breed_id_2", ("breed_id",)),
("idx_dogs_breed_id", ("name",)),
}
@pytest.mark.parametrize(
"extra_data,expected_new_columns",
[
({"species": "squirrels"}, [{"name": "species", "type": "TEXT"}]),
(
{"species": "squirrels", "hats": 5},
[{"name": "species", "type": "TEXT"}, {"name": "hats", "type": "INTEGER"}],
),
(
{"hats": 5, "rating": 3.5},
[{"name": "hats", "type": "INTEGER"}, {"name": "rating", "type": "FLOAT"}],
),
],
)
@pytest.mark.parametrize("use_table_factory", [True, False])
def test_insert_row_alter_table(
fresh_db, extra_data, expected_new_columns, use_table_factory
):
table = fresh_db["books"]
table.insert({"title": "Hedgehogs of the world", "author_id": 1})
assert [
{"name": "title", "type": "TEXT"},
{"name": "author_id", "type": "INTEGER"},
] == [{"name": col.name, "type": col.type} for col in table.columns]
record = {"title": "Squirrels of the world", "author_id": 2}
record.update(extra_data)
if use_table_factory:
fresh_db.table("books", alter=True).insert(record)
else:
fresh_db["books"].insert(record, alter=True)
assert [
{"name": "title", "type": "TEXT"},
{"name": "author_id", "type": "INTEGER"},
] + expected_new_columns == [
{"name": col.name, "type": col.type} for col in table.columns
]
def test_add_missing_columns_case_insensitive(fresh_db):
table = fresh_db["foo"]
table.insert({"id": 1, "name": "Cleo"}, pk="id")
table.add_missing_columns([{"Name": ".", "age": 4}])
assert (
table.schema
== "CREATE TABLE [foo] (\n [id] INTEGER PRIMARY KEY,\n [name] TEXT\n, [age] INTEGER)"
)
@pytest.mark.parametrize("use_table_factory", [True, False])
def test_insert_replace_rows_alter_table(fresh_db, use_table_factory):
first_row = {"id": 1, "title": "Hedgehogs of the world", "author_id": 1}
next_rows = [
{"id": 1, "title": "Hedgehogs of the World", "species": "hedgehogs"},
{"id": 2, "title": "Squirrels of the World", "num_species": 200},
{
"id": 3,
"title": "Badgers of the World",
"significant_continents": ["Europe", "North America"],
},
]
if use_table_factory:
table = fresh_db.table("books", pk="id", alter=True)
table.insert(first_row)
table.insert_all(next_rows, replace=True)
else:
table = fresh_db["books"]
table.insert(first_row, pk="id")
table.insert_all(next_rows, alter=True, replace=True)
assert {
"author_id": int,
"id": int,
"num_species": int,
"significant_continents": str,
"species": str,
"title": str,
} == table.columns_dict
assert [
{
"author_id": None,
"id": 1,
"num_species": None,
"significant_continents": None,
"species": "hedgehogs",
"title": "Hedgehogs of the World",
},
{
"author_id": None,
"id": 2,
"num_species": 200,
"significant_continents": None,
"species": None,
"title": "Squirrels of the World",
},
{
"author_id": None,
"id": 3,
"num_species": None,
"significant_continents": '["Europe", "North America"]',
"species": None,
"title": "Badgers of the World",
},
] == list(table.rows)
def test_insert_all_with_extra_columns_in_later_chunks(fresh_db):
chunk = [
{"record": "Record 1"},
{"record": "Record 2"},
{"record": "Record 3"},
{"record": "Record 4", "extra": 1},
]
fresh_db["t"].insert_all(chunk, batch_size=2, alter=True)
assert list(fresh_db["t"].rows) == [
{"record": "Record 1", "extra": None},
{"record": "Record 2", "extra": None},
{"record": "Record 3", "extra": None},
{"record": "Record 4", "extra": 1},
]
def test_bulk_insert_more_than_999_values(fresh_db):
"Inserting 100 items with 11 columns should work"
fresh_db["big"].insert_all(
(
{
"id": i + 1,
"c2": 2,
"c3": 3,
"c4": 4,
"c5": 5,
"c6": 6,
"c7": 7,
"c8": 8,
"c9": 9,
"c10": 10,
"c11": 11,
}
for i in range(100)
),
pk="id",
)
assert fresh_db["big"].count == 100
@pytest.mark.parametrize(
"num_columns,should_error", ((900, False), (999, False), (1000, True))
)
def test_error_if_more_than_999_columns(fresh_db, num_columns, should_error):
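    # 999 columns in a single insert is fine; 1,000 should be rejected by
    # sqlite-utils itself with an AssertionError (presumably to stay within
    # SQLite's historic 999 bound-parameter limit per statement)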
record = dict([("c{}".format(i), i) for i in range(num_columns)])
if should_error:
with pytest.raises(AssertionError):
fresh_db["big"].insert(record)
else:
fresh_db["big"].insert(record)
def test_columns_not_in_first_record_should_not_cause_batch_to_be_too_large(fresh_db):
    # https://github.com/simonw/sqlite-utils/issues/145
    # SQLite on Homebrew and Debian/Ubuntu etc. is typically compiled with
# SQLITE_MAX_VARIABLE_NUMBER set to 250,000, so we need to exceed this value to
# trigger the error on these systems.
THRESHOLD = 250000
batch_size = 999
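    # 251 columns per filler record means the 998 filler rows alone carry just
    # over THRESHOLD values, so a single batch sized from the one-column first
    # record would exceed SQLITE_MAX_VARIABLE_NUMBER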
extra_columns = 1 + (THRESHOLD - 1) // (batch_size - 1)
records = [
{"c0": "first record"}, # one column in first record -> batch size = 999
# fill out the batch with 99 records with enough columns to exceed THRESHOLD
*[
dict([("c{}".format(i), j) for i in range(extra_columns)])
for j in range(batch_size - 1)
],
]
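    # This insert should succeed; an OperationalError ("too many SQL variables")
    # would indicate the batch size was not recalculated for the wider records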
try:
fresh_db["too_many_columns"].insert_all(
records, alter=True, batch_size=batch_size
)
except sqlite3.OperationalError:
raise
@pytest.mark.parametrize(
"columns,index_name,expected_index",
(
(
["is good dog"],
None,
Index(
seq=0,
name="idx_dogs_is good dog",
unique=0,
origin="c",
partial=0,
columns=["is good dog"],
),
),
(
["is good dog", "age"],
None,
Index(
seq=0,
name="idx_dogs_is good dog_age",
unique=0,
origin="c",
partial=0,
columns=["is good dog", "age"],
),
),
(
["age"],
"age_index",
Index(
seq=0,
name="age_index",
unique=0,
origin="c",
partial=0,
columns=["age"],
),
),
),
)
def test_create_index(fresh_db, columns, index_name, expected_index):
dogs = fresh_db["dogs"]
dogs.insert({"name": "Cleo", "twitter": "cleopaws", "age": 3, "is good dog": True})
assert [] == dogs.indexes
dogs.create_index(columns, index_name)
assert expected_index == dogs.indexes[0]
def test_create_index_unique(fresh_db):
dogs = fresh_db["dogs"]
dogs.insert({"name": "Cleo", "twitter": "cleopaws", "age": 3, "is_good_dog": True})
assert [] == dogs.indexes
dogs.create_index(["name"], unique=True)
assert (
Index(
seq=0,
name="idx_dogs_name",
unique=1,
origin="c",
partial=0,
columns=["name"],
)
== dogs.indexes[0]
)
def test_create_index_if_not_exists(fresh_db):
dogs = fresh_db["dogs"]
dogs.insert({"name": "Cleo", "twitter": "cleopaws", "age": 3, "is_good_dog": True})
assert [] == dogs.indexes
dogs.create_index(["name"])
assert len(dogs.indexes) == 1
with pytest.raises(Exception, match="index idx_dogs_name already exists"):
dogs.create_index(["name"])
dogs.create_index(["name"], if_not_exists=True)
def test_create_index_desc(fresh_db):
dogs = fresh_db["dogs"]
dogs.insert({"name": "Cleo", "twitter": "cleopaws", "age": 3, "is good dog": True})
assert [] == dogs.indexes
dogs.create_index([DescIndex("age"), "name"])
sql = fresh_db.execute(
"select sql from sqlite_master where name='idx_dogs_age_name'"
).fetchone()[0]
assert sql == (
"CREATE INDEX [idx_dogs_age_name]\n" " ON [dogs] ([age] desc, [name])"
)
def test_create_index_find_unique_name(fresh_db):
table = fresh_db["t"]
table.insert({"id": 1})
table.create_index(["id"])
# Without find_unique_name should error
with pytest.raises(OperationalError, match="index idx_t_id already exists"):
table.create_index(["id"])
# With find_unique_name=True it should work
table.create_index(["id"], find_unique_name=True)
table.create_index(["id"], find_unique_name=True)
# Should have three now
index_names = {idx.name for idx in table.indexes}
assert index_names == {"idx_t_id", "idx_t_id_2", "idx_t_id_3"}
def test_create_index_analyze(fresh_db):
dogs = fresh_db["dogs"]
assert "sqlite_stat1" not in fresh_db.table_names()
dogs.insert({"name": "Cleo", "twitter": "cleopaws"})
dogs.create_index(["name"], analyze=True)
assert "sqlite_stat1" in fresh_db.table_names()
assert list(fresh_db["sqlite_stat1"].rows) == [
{"tbl": "dogs", "idx": "idx_dogs_name", "stat": "1 1"}
]
@pytest.mark.parametrize(
"data_structure",
(
["list with one item"],
["list with", "two items"],
{"dictionary": "simple"},
{"dictionary": {"nested": "complex"}},
collections.OrderedDict(
[
("key1", {"nested": ["cømplex"]}),
("key2", "foo"),
]
),
[{"list": "of"}, {"two": "dicts"}],
),
)
def test_insert_dictionaries_and_lists_as_json(fresh_db, data_structure):
fresh_db["test"].insert({"id": 1, "data": data_structure}, pk="id")
row = fresh_db.execute("select id, data from test").fetchone()
assert row[0] == 1
assert data_structure == json.loads(row[1])
def test_insert_list_nested_unicode(fresh_db):
fresh_db["test"].insert(
{"id": 1, "data": {"key1": {"nested": ["cømplex"]}}}, pk="id"
)
row = fresh_db.execute("select id, data from test").fetchone()
assert row[1] == '{"key1": {"nested": ["cømplex"]}}'
def test_insert_uuid(fresh_db):
uuid4 = uuid.uuid4()
fresh_db["test"].insert({"uuid": uuid4})
row = list(fresh_db["test"].rows)[0]
assert {"uuid"} == row.keys()
assert isinstance(row["uuid"], str)
assert row["uuid"] == str(uuid4)
def test_insert_memoryview(fresh_db):
fresh_db["test"].insert({"data": memoryview(b"hello")})
row = list(fresh_db["test"].rows)[0]
assert {"data"} == row.keys()
assert isinstance(row["data"], bytes)
assert row["data"] == b"hello"
def test_insert_thousands_using_generator(fresh_db):
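    # insert_all() should accept a lazy generator, inserting in batches and
    # inferring column types from the records it sees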
fresh_db["test"].insert_all(
{"i": i, "word": "word_{}".format(i)} for i in range(10000)
)
assert [{"name": "i", "type": "INTEGER"}, {"name": "word", "type": "TEXT"}] == [
{"name": col.name, "type": col.type} for col in fresh_db["test"].columns
]
assert fresh_db["test"].count == 10000
def test_insert_thousands_raises_exception_with_extra_columns_after_first_100(fresh_db):
    # https://github.com/simonw/sqlite-utils/issues/139
with pytest.raises(Exception, match="table test has no column named extra"):
fresh_db["test"].insert_all(
[{"i": i, "word": "word_{}".format(i)} for i in range(100)]
+ [{"i": 101, "extra": "This extra column should cause an exception"}],
)
def test_insert_thousands_adds_extra_columns_after_first_100_with_alter(fresh_db):
    # https://github.com/simonw/sqlite-utils/issues/139
fresh_db["test"].insert_all(
[{"i": i, "word": "word_{}".format(i)} for i in range(100)]
+ [{"i": 101, "extra": "Should trigger ALTER"}],
alter=True,
)
rows = list(fresh_db.query("select * from test where i = 101"))
assert rows == [{"i": 101, "word": None, "extra": "Should trigger ALTER"}]
def test_insert_ignore(fresh_db):
fresh_db["test"].insert({"id": 1, "bar": 2}, pk="id")
# Should raise an error if we try this again
with pytest.raises(Exception, match="UNIQUE constraint failed"):
fresh_db["test"].insert({"id": 1, "bar": 2}, pk="id")
# Using ignore=True should cause our insert to be silently ignored
fresh_db["test"].insert({"id": 1, "bar": 3}, pk="id", ignore=True)
# Only one row, and it should be bar=2, not bar=3
rows = list(fresh_db.query("select * from test"))
assert rows == [{"id": 1, "bar": 2}]
def test_insert_hash_id(fresh_db):
dogs = fresh_db["dogs"]
id = dogs.insert({"name": "Cleo", "twitter": "cleopaws"}, hash_id="id").last_pk
assert "f501265970505d9825d8d9f590bfab3519fb20b1" == id
assert dogs.count == 1
# Insert replacing a second time should not create a new row
id2 = dogs.insert(
{"name": "Cleo", "twitter": "cleopaws"}, hash_id="id", replace=True
).last_pk
assert "f501265970505d9825d8d9f590bfab3519fb20b1" == id2
assert dogs.count == 1
@pytest.mark.parametrize("use_table_factory", [True, False])
def test_insert_hash_id_columns(fresh_db, use_table_factory):
if use_table_factory:
dogs = fresh_db.table("dogs", hash_id_columns=("name", "twitter"))
insert_kwargs = {}
else:
dogs = fresh_db["dogs"]
insert_kwargs = dict(hash_id_columns=("name", "twitter"))
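    # Only name and twitter feed the hash, so changing age should not produce a new primary key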
id = dogs.insert(
{"name": "Cleo", "twitter": "cleopaws", "age": 5},
**insert_kwargs,
).last_pk
expected_hash = hash_record({"name": "Cleo", "twitter": "cleopaws"})
assert id == expected_hash
assert dogs.count == 1
# Insert replacing a second time should not create a new row
id2 = dogs.insert(
{"name": "Cleo", "twitter": "cleopaws", "age": 6},
**insert_kwargs,
replace=True,
).last_pk
assert id2 == expected_hash
assert dogs.count == 1
def test_vacuum(fresh_db):
fresh_db["data"].insert({"foo": "foo", "bar": "bar"})
fresh_db.vacuum()
def test_works_with_pathlib_path(tmpdir):
path = pathlib.Path(tmpdir / "test.db")
db = Database(path)
db["demo"].insert_all([{"foo": 1}])
assert db["demo"].count == 1
@pytest.mark.skipif(pd is None, reason="pandas and numpy are not installed")
def test_create_table_numpy(fresh_db):
df = pd.DataFrame({"col 1": range(3), "col 2": range(3)})
fresh_db["pandas"].insert_all(df.to_dict(orient="records"))
assert [
{"col 1": 0, "col 2": 0},
{"col 1": 1, "col 2": 1},
{"col 1": 2, "col 2": 2},
] == list(fresh_db["pandas"].rows)
# Now try all the different types
df = pd.DataFrame(
{
"np.int8": [-8],
"np.int16": [-16],
"np.int32": [-32],
"np.int64": [-64],
"np.uint8": [8],
"np.uint16": [16],
"np.uint32": [32],
"np.uint64": [64],
"np.float16": [16.5],
"np.float32": [32.5],
"np.float64": [64.5],
}
)
df = df.astype(
{
"np.int8": "int8",
"np.int16": "int16",
"np.int32": "int32",
"np.int64": "int64",
"np.uint8": "uint8",
"np.uint16": "uint16",
"np.uint32": "uint32",
"np.uint64": "uint64",
"np.float16": "float16",
"np.float32": "float32",
"np.float64": "float64",
}
)
assert [
"int8",
"int16",
"int32",
"int64",
"uint8",
"uint16",
"uint32",
"uint64",
"float16",
"float32",
"float64",
] == [str(t) for t in df.dtypes]
fresh_db["types"].insert_all(df.to_dict(orient="records"))
assert [
{
"np.float16": 16.5,
"np.float32": 32.5,
"np.float64": 64.5,
"np.int16": -16,
"np.int32": -32,
"np.int64": -64,
"np.int8": -8,
"np.uint16": 16,
"np.uint32": 32,
"np.uint64": 64,
"np.uint8": 8,
}
] == list(fresh_db["types"].rows)
def test_cannot_provide_both_filename_and_memory():
with pytest.raises(
AssertionError, match="Either specify a filename_or_conn or pass memory=True"
):
Database("/tmp/foo.db", memory=True)
def test_creates_id_column(fresh_db):
last_pk = fresh_db.table("cats", pk="id").insert({"name": "barry"}).last_pk
assert [{"name": "barry", "id": last_pk}] == list(fresh_db["cats"].rows)
def test_drop(fresh_db):
fresh_db["t"].insert({"foo": 1})
assert ["t"] == fresh_db.table_names()
assert None is fresh_db["t"].drop()
assert [] == fresh_db.table_names()
def test_drop_view(fresh_db):
fresh_db.create_view("foo_view", "select 1")
assert ["foo_view"] == fresh_db.view_names()
assert None is fresh_db["foo_view"].drop()
assert [] == fresh_db.view_names()
def test_drop_ignore(fresh_db):
with pytest.raises(sqlite3.OperationalError):
fresh_db["does_not_exist"].drop()
fresh_db["does_not_exist"].drop(ignore=True)
    # Testing a view is harder: we need to create it in order
    # to get a View object, then drop it twice
fresh_db.create_view("foo_view", "select 1")
view = fresh_db["foo_view"]
assert isinstance(view, View)
view.drop()
with pytest.raises(sqlite3.OperationalError):
view.drop()
view.drop(ignore=True)
def test_insert_all_empty_list(fresh_db):
fresh_db["t"].insert({"foo": 1})
assert fresh_db["t"].count == 1
fresh_db["t"].insert_all([])
assert fresh_db["t"].count == 1
fresh_db["t"].insert_all([], replace=True)
assert fresh_db["t"].count == 1
def test_insert_all_single_column(fresh_db):
table = fresh_db["table"]
table.insert_all([{"name": "Cleo"}], pk="name")
assert [{"name": "Cleo"}] == list(table.rows)
assert table.pks == ["name"]
@pytest.mark.parametrize("method_name", ("insert_all", "upsert_all"))
def test_insert_all_analyze(fresh_db, method_name):
table = fresh_db["table"]
table.insert_all([{"id": 1, "name": "Cleo"}], pk="id")
assert "sqlite_stat1" not in fresh_db.table_names()
table.create_index(["name"], analyze=True)
assert list(fresh_db["sqlite_stat1"].rows) == [
{"tbl": "table", "idx": "idx_table_name", "stat": "1 1"}
]
method = getattr(table, method_name)
method([{"id": 2, "name": "Suna"}], pk="id", analyze=True)
assert "sqlite_stat1" in fresh_db.table_names()
assert list(fresh_db["sqlite_stat1"].rows) == [
{"tbl": "table", "idx": "idx_table_name", "stat": "2 1"}
]
def test_create_with_a_null_column(fresh_db):
record = {"name": "Name", "description": None}
fresh_db["t"].insert(record)
assert [record] == list(fresh_db["t"].rows)
def test_create_with_nested_bytes(fresh_db):
record = {"id": 1, "data": {"foo": b"bytes"}}
fresh_db["t"].insert(record)
assert [{"id": 1, "data": '{"foo": "b\'bytes\'"}'}] == list(fresh_db["t"].rows)
@pytest.mark.parametrize(
"input,expected", [("hello", "'hello'"), ("hello'there'", "'hello''there'''")]
)
def test_quote(fresh_db, input, expected):
assert fresh_db.quote(input) == expected
@pytest.mark.parametrize(
"columns,expected_sql_middle",
(
(
{"id": int},
"[id] INTEGER",
),
(
{"col": dict},
"[col] TEXT",
),
(
{"col": tuple},
"[col] TEXT",
),
(
{"col": list},
"[col] TEXT",
),
),
)
def test_create_table_sql(fresh_db, columns, expected_sql_middle):
sql = fresh_db.create_table_sql("t", columns)
middle = sql.split("(")[1].split(")")[0].strip()
assert middle == expected_sql_middle
def test_create(fresh_db):
fresh_db["t"].create(
{
"id": int,
"text": str,
"float": float,
"integer": int,
"bytes": bytes,
},
pk="id",
column_order=("id", "float"),
not_null=("float", "integer"),
defaults={"integer": 0},
)
assert fresh_db["t"].schema == (
"CREATE TABLE [t] (\n"
" [id] INTEGER PRIMARY KEY,\n"
" [float] FLOAT NOT NULL,\n"
" [text] TEXT,\n"
" [integer] INTEGER NOT NULL DEFAULT 0,\n"
" [bytes] BLOB\n"
")"
)
def test_create_if_not_exists(fresh_db):
fresh_db["t"].create({"id": int})
# This should error
with pytest.raises(sqlite3.OperationalError):
fresh_db["t"].create({"id": int})
# This should not
fresh_db["t"].create({"id": int}, if_not_exists=True)
def test_create_if_no_columns(fresh_db):
with pytest.raises(AssertionError) as error:
fresh_db["t"].create({})
assert error.value.args[0] == "Tables must have at least one column"
def test_create_ignore(fresh_db):
fresh_db["t"].create({"id": int})
# This should error
with pytest.raises(sqlite3.OperationalError):
fresh_db["t"].create({"id": int})
# This should not
fresh_db["t"].create({"id": int}, ignore=True)
def test_create_replace(fresh_db):
fresh_db["t"].create({"id": int})
# This should error
with pytest.raises(sqlite3.OperationalError):
fresh_db["t"].create({"id": int})
# This should not
fresh_db["t"].create({"name": str}, replace=True)
assert fresh_db["t"].schema == ("CREATE TABLE [t] (\n" " [name] TEXT\n" ")")
@pytest.mark.parametrize(
"cols,kwargs,expected_schema,should_transform",
(
# Change nothing
(
{"id": int, "name": str},
{"pk": "id"},
"CREATE TABLE [demo] (\n [id] INTEGER PRIMARY KEY,\n [name] TEXT\n)",
False,
),
# Drop name column, remove primary key
({"id": int}, {}, 'CREATE TABLE "demo" (\n [id] INTEGER\n)', True),
# Add a new column
(
{"id": int, "name": str, "age": int},
{"pk": "id"},
'CREATE TABLE "demo" (\n [id] INTEGER PRIMARY KEY,\n [name] TEXT,\n [age] INTEGER\n)',
True,
),
# Change a column type
(
{"id": int, "name": bytes},
{"pk": "id"},
'CREATE TABLE "demo" (\n [id] INTEGER PRIMARY KEY,\n [name] BLOB\n)',
True,
),
# Change the primary key
(
{"id": int, "name": str},
{"pk": "name"},
'CREATE TABLE "demo" (\n [id] INTEGER,\n [name] TEXT PRIMARY KEY\n)',
True,
),
# Change in column order
(
{"id": int, "name": str},
{"pk": "id", "column_order": ["name"]},
'CREATE TABLE "demo" (\n [name] TEXT,\n [id] INTEGER PRIMARY KEY\n)',
True,
),
# Same column order is ignored
(
{"id": int, "name": str},
{"pk": "id", "column_order": ["id", "name"]},
"CREATE TABLE [demo] (\n [id] INTEGER PRIMARY KEY,\n [name] TEXT\n)",
False,
),
# Change not null
(
{"id": int, "name": str},
{"pk": "id", "not_null": {"name"}},
'CREATE TABLE "demo" (\n [id] INTEGER PRIMARY KEY,\n [name] TEXT NOT NULL\n)',
True,
),
# Change default values
(
{"id": int, "name": str},
{"pk": "id", "defaults": {"id": 0, "name": "Bob"}},
"CREATE TABLE \"demo\" (\n [id] INTEGER PRIMARY KEY DEFAULT 0,\n [name] TEXT DEFAULT 'Bob'\n)",
True,
),
),
)
def test_create_transform(fresh_db, cols, kwargs, expected_schema, should_transform):
fresh_db.create_table("demo", {"id": int, "name": str}, pk="id")
fresh_db["demo"].insert({"id": 1, "name": "Cleo"})
traces = []
with fresh_db.tracer(lambda sql, parameters: traces.append((sql, parameters))):
fresh_db["demo"].create(cols, **kwargs, transform=True)
at_least_one_create_table = any(sql.startswith("CREATE TABLE") for sql, _ in traces)
assert should_transform == at_least_one_create_table
new_schema = fresh_db["demo"].schema
assert new_schema == expected_schema, repr(new_schema)
assert fresh_db["demo"].count == 1
def test_rename_table(fresh_db):
fresh_db["t"].insert({"foo": "bar"})
assert ["t"] == fresh_db.table_names()
fresh_db.rename_table("t", "renamed")
assert ["renamed"] == fresh_db.table_names()
assert [{"foo": "bar"}] == list(fresh_db["renamed"].rows)
# Should error if table does not exist:
with pytest.raises(sqlite3.OperationalError):
fresh_db.rename_table("does_not_exist", "renamed")
@pytest.mark.parametrize("strict", (False, True))
def test_database_strict(strict):
db = Database(memory=True, strict=strict)
table = db.table("t", columns={"id": int})
table.insert({"id": 1})
assert table.strict == strict or not db.supports_strict
@pytest.mark.parametrize("strict", (False, True))
def test_database_strict_override(strict):
db = Database(memory=True, strict=strict)
table = db.table("t", columns={"id": int}, strict=not strict)
table.insert({"id": 1})
assert table.strict != strict or not db.supports_strict
@pytest.mark.parametrize(
"method_name", ("insert", "upsert", "insert_all", "upsert_all")
)
@pytest.mark.parametrize("strict", (False, True))
def test_insert_upsert_strict(fresh_db, method_name, strict):
table = fresh_db["t"]
method = getattr(table, method_name)
record = {"id": 1}
if method_name.endswith("_all"):
record = [record]
method(record, pk="id", strict=strict)
assert table.strict == strict or not fresh_db.supports_strict
@pytest.mark.parametrize("strict", (False, True))
def test_create_table_strict(fresh_db, strict):
table = fresh_db.create_table("t", {"id": int, "f": float}, strict=strict)
assert table.strict == strict or not fresh_db.supports_strict
expected_schema = "CREATE TABLE [t] (\n" " [id] INTEGER,\n" " [f] FLOAT\n" ")"
if strict and not fresh_db.supports_strict:
return
if strict:
expected_schema = "CREATE TABLE [t] (\n [id] INTEGER,\n [f] REAL\n) STRICT"
assert table.schema == expected_schema
@pytest.mark.parametrize("strict", (False, True))
def test_create_strict(fresh_db, strict):
table = fresh_db["t"]
table.create({"id": int}, strict=strict)
assert table.strict == strict or not fresh_db.supports_strict
</document_content>
</document>
<document index="27">
<source>./tests/test_create_view.py</source>
<document_content>
import pytest
from sqlite_utils.utils import OperationalError
def test_create_view(fresh_db):
fresh_db.create_view("bar", "select 1 + 1")
rows = fresh_db.execute("select * from bar").fetchall()
assert [(2,)] == rows
def test_create_view_error(fresh_db):
fresh_db.create_view("bar", "select 1 + 1")
with pytest.raises(OperationalError):
fresh_db.create_view("bar", "select 1 + 2")
def test_create_view_only_allow_one_param(fresh_db):
with pytest.raises(AssertionError):
fresh_db.create_view("bar", "select 1 + 2", ignore=True, replace=True)
def test_create_view_ignore(fresh_db):
fresh_db.create_view("bar", "select 1 + 1").create_view(
"bar", "select 1 + 2", ignore=True
)
rows = fresh_db.execute("select * from bar").fetchall()
assert [(2,)] == rows
def test_create_view_replace(fresh_db):
fresh_db.create_view("bar", "select 1 + 1").create_view(
"bar", "select 1 + 2", replace=True
)
rows = fresh_db.execute("select * from bar").fetchall()
assert [(3,)] == rows
def test_create_view_replace_with_same_does_nothing(fresh_db):
fresh_db.create_view("bar", "select 1 + 1")
initial_version = fresh_db.execute("PRAGMA schema_version").fetchone()[0]
fresh_db.create_view("bar", "select 1 + 1", replace=True)
after_version = fresh_db.execute("PRAGMA schema_version").fetchone()[0]
assert after_version == initial_version
</document_content>
</document>
<document index="28">
<source>./tests/test_default_value.py</source>
<document_content>
import pytest
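# Each example is (column definition, introspected default_value, expected output of quote_default_value())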
EXAMPLES = [
("TEXT DEFAULT 'foo'", "'foo'", "'foo'"),
("TEXT DEFAULT 'foo)'", "'foo)'", "'foo)'"),
("INTEGER DEFAULT '1'", "'1'", "'1'"),
("INTEGER DEFAULT 1", "1", "'1'"),
("INTEGER DEFAULT (1)", "1", "'1'"),
# Expressions
(
"TEXT DEFAULT (STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))",
"STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')",
"(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))",
),
# Special values
("TEXT DEFAULT CURRENT_TIME", "CURRENT_TIME", "CURRENT_TIME"),
("TEXT DEFAULT CURRENT_DATE", "CURRENT_DATE", "CURRENT_DATE"),
("TEXT DEFAULT CURRENT_TIMESTAMP", "CURRENT_TIMESTAMP", "CURRENT_TIMESTAMP"),
("TEXT DEFAULT current_timestamp", "current_timestamp", "current_timestamp"),
("TEXT DEFAULT (CURRENT_TIMESTAMP)", "CURRENT_TIMESTAMP", "CURRENT_TIMESTAMP"),
# Strings
("TEXT DEFAULT 'CURRENT_TIMESTAMP'", "'CURRENT_TIMESTAMP'", "'CURRENT_TIMESTAMP'"),
('TEXT DEFAULT "CURRENT_TIMESTAMP"', '"CURRENT_TIMESTAMP"', '"CURRENT_TIMESTAMP"'),
]
@pytest.mark.parametrize("column_def,initial_value,expected_value", EXAMPLES)
def test_quote_default_value(fresh_db, column_def, initial_value, expected_value):
fresh_db.execute("create table foo (col {})".format(column_def))
assert initial_value == fresh_db["foo"].columns[0].default_value
assert expected_value == fresh_db.quote_default_value(
fresh_db["foo"].columns[0].default_value
)
</document_content>
</document>
<document index="29">
<source>./tests/test_delete.py</source>
<document_content>
def test_delete_rowid_table(fresh_db):
table = fresh_db["table"]
table.insert({"foo": 1}).last_pk
rowid = table.insert({"foo": 2}).last_pk
table.delete(rowid)
assert [{"foo": 1}] == list(table.rows)
def test_delete_pk_table(fresh_db):
table = fresh_db["table"]
table.insert({"id": 1}, pk="id")
table.insert({"id": 2}, pk="id")
table.delete(1)
assert [{"id": 2}] == list(table.rows)
def test_delete_where(fresh_db):
table = fresh_db["table"]
for i in range(1, 11):
table.insert({"id": i}, pk="id")
assert table.count == 10
table.delete_where("id > ?", [5])
assert table.count == 5
def test_delete_where_all(fresh_db):
table = fresh_db["table"]
for i in range(1, 11):
table.insert({"id": i}, pk="id")
assert table.count == 10
table.delete_where()
assert table.count == 0
def test_delete_where_analyze(fresh_db):
table = fresh_db["table"]
table.insert_all(({"id": i, "i": i} for i in range(10)), pk="id")
table.create_index(["i"], analyze=True)
assert "sqlite_stat1" in fresh_db.table_names()
assert list(fresh_db["sqlite_stat1"].rows) == [
{"tbl": "table", "idx": "idx_table_i", "stat": "10 1"}
]
table.delete_where("id > ?", [5], analyze=True)
assert list(fresh_db["sqlite_stat1"].rows) == [
{"tbl": "table", "idx": "idx_table_i", "stat": "6 1"}
]
</document_content>
</document>
<document index="30">
<source>./tests/test_docs.py</source>
<document_content>
from click.testing import CliRunner
from sqlite_utils import cli, recipes
from pathlib import Path
import pytest
import re
docs_path = Path(__file__).parent.parent / "docs"
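# Extract "sqlite-utils <command>" invocations and "r.<recipe>(" calls from the documentation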
commands_re = re.compile(r"(?:\$ | )sqlite-utils (\S+)")
recipes_re = re.compile(r"r\.(\w+)\(")
@pytest.fixture(scope="session")
def documented_commands():
rst = ""
for doc in ("cli.rst", "plugins.rst"):
rst += (docs_path / doc).read_text()
return {
command
for command in commands_re.findall(rst)
if "." not in command and ":" not in command
}
@pytest.fixture(scope="session")
def documented_recipes():
rst = (docs_path / "cli.rst").read_text()
return set(recipes_re.findall(rst))
@pytest.mark.parametrize("command", cli.cli.commands.keys())
def test_commands_are_documented(documented_commands, command):
assert command in documented_commands
@pytest.mark.parametrize("command", cli.cli.commands.values())
def test_commands_have_help(command):
assert command.help, "{} is missing its help".format(command)
def test_convert_help():
result = CliRunner().invoke(cli.cli, ["convert", "--help"])
assert result.exit_code == 0
for expected in (
"r.jsonsplit(value, ",
"r.parsedate(value, ",
"r.parsedatetime(value, ",
):
assert expected in result.output
@pytest.mark.parametrize(
"recipe",
[
n
for n in dir(recipes)
if not n.startswith("_")
and n not in ("json", "parser")
and callable(getattr(recipes, n))
],
)
def test_recipes_are_documented(documented_recipes, recipe):
assert recipe in documented_recipes
</document_content>
</document>
<document index="31">
<source>./tests/test_duplicate.py</source>
<document_content>
from sqlite_utils.db import NoTable
import datetime
import pytest
def test_duplicate(fresh_db):
    # Create table using a native SQLite statement:
fresh_db.execute(
"""CREATE TABLE [table1] (
[text_col] TEXT,
[real_col] REAL,
[int_col] INTEGER,
[bool_col] INTEGER,
[datetime_col] TEXT)"""
)
# Insert one row of mock data:
dt = datetime.datetime.now()
data = {
"text_col": "Cleo",
"real_col": 3.14,
"int_col": -255,
"bool_col": True,
"datetime_col": str(dt),
}
table1 = fresh_db["table1"]
row_id = table1.insert(data).last_rowid
# Duplicate table:
table2 = table1.duplicate("table2")
# Ensure data integrity:
assert data == table2.get(row_id)
# Ensure schema integrity:
assert [
{"name": "text_col", "type": "TEXT"},
{"name": "real_col", "type": "REAL"},
{"name": "int_col", "type": "INT"},
{"name": "bool_col", "type": "INT"},
{"name": "datetime_col", "type": "TEXT"},
] == [{"name": col.name, "type": col.type} for col in table2.columns]
def test_duplicate_fails_if_table_does_not_exist(fresh_db):
with pytest.raises(NoTable):
fresh_db["not_a_table"].duplicate("duplicated")
</document_content>
</document>
<document index="32">
<source>./tests/test_enable_counts.py</source>
<document_content>
from sqlite_utils import Database
from sqlite_utils import cli
from click.testing import CliRunner
import pytest
def test_enable_counts_specific_table(fresh_db):
foo = fresh_db["foo"]
assert fresh_db.table_names() == []
for i in range(10):
foo.insert({"name": "item {}".format(i)})
assert fresh_db.table_names() == ["foo"]
assert foo.count == 10
# Now enable counts
foo.enable_counts()
assert foo.triggers_dict == {
"foo_counts_insert": (
"CREATE TRIGGER [foo_counts_insert] AFTER INSERT ON [foo]\n"
"BEGIN\n"
" INSERT OR REPLACE INTO [_counts]\n"
" VALUES (\n 'foo',\n"
" COALESCE(\n"
" (SELECT count FROM [_counts] WHERE [table] = 'foo'),\n"
" 0\n"
" ) + 1\n"
" );\n"
"END"
),
"foo_counts_delete": (
"CREATE TRIGGER [foo_counts_delete] AFTER DELETE ON [foo]\n"
"BEGIN\n"
" INSERT OR REPLACE INTO [_counts]\n"
" VALUES (\n"
" 'foo',\n"
" COALESCE(\n"
" (SELECT count FROM [_counts] WHERE [table] = 'foo'),\n"
" 0\n"
" ) - 1\n"
" );\n"
"END"
),
}
assert fresh_db.table_names() == ["foo", "_counts"]
assert list(fresh_db["_counts"].rows) == [{"count": 10, "table": "foo"}]
# Add some items to test the triggers
for i in range(5):
foo.insert({"name": "item {}".format(10 + i)})
assert foo.count == 15
assert list(fresh_db["_counts"].rows) == [{"count": 15, "table": "foo"}]
# Delete some items
foo.delete_where("rowid < 7")
assert foo.count == 9
assert list(fresh_db["_counts"].rows) == [{"count": 9, "table": "foo"}]
foo.delete_where()
assert foo.count == 0
assert list(fresh_db["_counts"].rows) == [{"count": 0, "table": "foo"}]
def test_enable_counts_all_tables(fresh_db):
foo = fresh_db["foo"]
bar = fresh_db["bar"]
foo.insert({"name": "Cleo"})
bar.insert({"name": "Cleo"})
foo.enable_fts(["name"])
fresh_db.enable_counts()
assert set(fresh_db.table_names()) == {
"foo",
"bar",
"foo_fts",
"foo_fts_data",
"foo_fts_idx",
"foo_fts_docsize",
"foo_fts_config",
"_counts",
}
assert list(fresh_db["_counts"].rows) == [
{"count": 1, "table": "foo"},
{"count": 1, "table": "bar"},
{"count": 3, "table": "foo_fts_data"},
{"count": 1, "table": "foo_fts_idx"},
{"count": 1, "table": "foo_fts_docsize"},
{"count": 1, "table": "foo_fts_config"},
]
@pytest.fixture
def counts_db_path(tmpdir):
path = str(tmpdir / "test.db")
db = Database(path)
db["foo"].insert({"name": "bar"})
db["bar"].insert({"name": "bar"})
db["bar"].insert({"name": "bar"})
db["baz"].insert({"name": "bar"})
return path
@pytest.mark.parametrize(
"extra_args,expected_triggers",
[
(
[],
[
"foo_counts_insert",
"foo_counts_delete",
"bar_counts_insert",
"bar_counts_delete",
"baz_counts_insert",
"baz_counts_delete",
],
),
(
["bar"],
[
"bar_counts_insert",
"bar_counts_delete",
],
),
],
)
def test_cli_enable_counts(counts_db_path, extra_args, expected_triggers):
db = Database(counts_db_path)
assert list(db.triggers_dict.keys()) == []
result = CliRunner().invoke(cli.cli, ["enable-counts", counts_db_path] + extra_args)
assert result.exit_code == 0
assert list(db.triggers_dict.keys()) == expected_triggers
def test_uses_counts_after_enable_counts(counts_db_path):
db = Database(counts_db_path)
logged = []
with db.tracer(lambda sql, parameters: logged.append((sql, parameters))):
assert db["foo"].count == 1
assert logged == [
("select name from sqlite_master where type = 'view'", None),
("select count(*) from [foo]", []),
]
logged.clear()
assert not db.use_counts_table
db.enable_counts()
assert db.use_counts_table
assert db["foo"].count == 1
assert logged == [
(
"CREATE TABLE IF NOT EXISTS [_counts](\n [table] TEXT PRIMARY KEY,\n count INTEGER DEFAULT 0\n);",
None,
),
("select name from sqlite_master where type = 'table'", None),
("select name from sqlite_master where type = 'view'", None),
("select name from sqlite_master where type = 'view'", None),
("select name from sqlite_master where type = 'view'", None),
("select name from sqlite_master where type = 'view'", None),
("select sql from sqlite_master where name = ?", ("foo",)),
("SELECT quote(:value)", {"value": "foo"}),
("select sql from sqlite_master where name = ?", ("bar",)),
("SELECT quote(:value)", {"value": "bar"}),
("select sql from sqlite_master where name = ?", ("baz",)),
("SELECT quote(:value)", {"value": "baz"}),
("select sql from sqlite_master where name = ?", ("_counts",)),
("select name from sqlite_master where type = 'view'", None),
("select [table], count from _counts where [table] in (?)", ["foo"]),
]
def test_reset_counts(counts_db_path):
db = Database(counts_db_path)
db["foo"].enable_counts()
db["bar"].enable_counts()
assert db.cached_counts() == {"foo": 1, "bar": 2}
# Corrupt the value
db["_counts"].update("foo", {"count": 3})
assert db.cached_counts() == {"foo": 3, "bar": 2}
assert db["foo"].count == 3
# Reset them
db.reset_counts()
assert db.cached_counts() == {"foo": 1, "bar": 2}
assert db["foo"].count == 1
def test_reset_counts_cli(counts_db_path):
db = Database(counts_db_path)
db["foo"].enable_counts()
db["bar"].enable_counts()
assert db.cached_counts() == {"foo": 1, "bar": 2}
db["_counts"].update("foo", {"count": 3})
result = CliRunner().invoke(cli.cli, ["reset-counts", counts_db_path])
assert result.exit_code == 0
assert db.cached_counts() == {"foo": 1, "bar": 2}
</document_content>
</document>
<document index="33">
<source>./tests/test_extract.py</source>
<document_content>
from sqlite_utils.db import InvalidColumns
import itertools
import pytest
@pytest.mark.parametrize("table", [None, "Species"])
@pytest.mark.parametrize("fk_column", [None, "species"])
def test_extract_single_column(fresh_db, table, fk_column):
expected_table = table or "species"
expected_fk = fk_column or "{}_id".format(expected_table)
iter_species = itertools.cycle(["Palm", "Spruce", "Mangrove", "Oak"])
fresh_db["tree"].insert_all(
(
{
"id": i,
"name": "Tree {}".format(i),
"species": next(iter_species),
"end": 1,
}
for i in range(1, 1001)
),
pk="id",
)
fresh_db["tree"].extract("species", table=table, fk_column=fk_column)
assert fresh_db["tree"].schema == (
'CREATE TABLE "tree" (\n'
" [id] INTEGER PRIMARY KEY,\n"
" [name] TEXT,\n"
" [{}] INTEGER REFERENCES [{}]([id]),\n".format(expected_fk, expected_table)
+ " [end] INTEGER\n"
+ ")"
)
assert fresh_db[expected_table].schema == (
"CREATE TABLE [{}] (\n".format(expected_table)
+ " [id] INTEGER PRIMARY KEY,\n"
" [species] TEXT\n"
")"
)
assert list(fresh_db[expected_table].rows) == [
{"id": 1, "species": "Palm"},
{"id": 2, "species": "Spruce"},
{"id": 3, "species": "Mangrove"},
{"id": 4, "species": "Oak"},
]
assert list(itertools.islice(fresh_db["tree"].rows, 0, 4)) == [
{"id": 1, "name": "Tree 1", expected_fk: 1, "end": 1},
{"id": 2, "name": "Tree 2", expected_fk: 2, "end": 1},
{"id": 3, "name": "Tree 3", expected_fk: 3, "end": 1},
{"id": 4, "name": "Tree 4", expected_fk: 4, "end": 1},
]
def test_extract_multiple_columns_with_rename(fresh_db):
iter_common = itertools.cycle(["Palm", "Spruce", "Mangrove", "Oak"])
iter_latin = itertools.cycle(["Arecaceae", "Picea", "Rhizophora", "Quercus"])
fresh_db["tree"].insert_all(
(
{
"id": i,
"name": "Tree {}".format(i),
"common_name": next(iter_common),
"latin_name": next(iter_latin),
}
for i in range(1, 1001)
),
pk="id",
)
fresh_db["tree"].extract(
["common_name", "latin_name"], rename={"common_name": "name"}
)
assert fresh_db["tree"].schema == (
'CREATE TABLE "tree" (\n'
" [id] INTEGER PRIMARY KEY,\n"
" [name] TEXT,\n"
" [common_name_latin_name_id] INTEGER REFERENCES [common_name_latin_name]([id])\n"
")"
)
assert fresh_db["common_name_latin_name"].schema == (
"CREATE TABLE [common_name_latin_name] (\n"
" [id] INTEGER PRIMARY KEY,\n"
" [name] TEXT,\n"
" [latin_name] TEXT\n"
")"
)
assert list(fresh_db["common_name_latin_name"].rows) == [
{"name": "Palm", "id": 1, "latin_name": "Arecaceae"},
{"name": "Spruce", "id": 2, "latin_name": "Picea"},
{"name": "Mangrove", "id": 3, "latin_name": "Rhizophora"},
{"name": "Oak", "id": 4, "latin_name": "Quercus"},
]
assert list(itertools.islice(fresh_db["tree"].rows, 0, 4)) == [
{"id": 1, "name": "Tree 1", "common_name_latin_name_id": 1},
{"id": 2, "name": "Tree 2", "common_name_latin_name_id": 2},
{"id": 3, "name": "Tree 3", "common_name_latin_name_id": 3},
{"id": 4, "name": "Tree 4", "common_name_latin_name_id": 4},
]
def test_extract_invalid_columns(fresh_db):
fresh_db["tree"].insert(
{
"id": 1,
"name": "Tree 1",
"common_name": "Palm",
"latin_name": "Arecaceae",
},
pk="id",
)
with pytest.raises(InvalidColumns):
fresh_db["tree"].extract(["bad_column"])
def test_extract_rowid_table(fresh_db):
fresh_db["tree"].insert(
{
"name": "Tree 1",
"common_name": "Palm",
"latin_name": "Arecaceae",
}
)
fresh_db["tree"].extract(["common_name", "latin_name"])
assert fresh_db["tree"].schema == (
'CREATE TABLE "tree" (\n'
" [name] TEXT,\n"
" [common_name_latin_name_id] INTEGER REFERENCES [common_name_latin_name]([id])\n"
")"
)
assert (
fresh_db.execute(
"""
select
tree.name,
common_name_latin_name.common_name,
common_name_latin_name.latin_name
from tree
join common_name_latin_name
on tree.common_name_latin_name_id = common_name_latin_name.id
"""
).fetchall()
== [("Tree 1", "Palm", "Arecaceae")]
)
def test_reuse_lookup_table(fresh_db):
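    # extract() should reuse the existing, compatible "species" lookup table,
    # adding "Fox" as a new row rather than creating a second table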
fresh_db["species"].insert({"id": 1, "name": "Wolf"}, pk="id")
fresh_db["sightings"].insert({"id": 10, "species": "Wolf"}, pk="id")
fresh_db["individuals"].insert(
{"id": 10, "name": "Terriana", "species": "Fox"}, pk="id"
)
fresh_db["sightings"].extract("species", rename={"species": "name"})
fresh_db["individuals"].extract("species", rename={"species": "name"})
assert fresh_db["sightings"].schema == (
'CREATE TABLE "sightings" (\n'
" [id] INTEGER PRIMARY KEY,\n"
" [species_id] INTEGER REFERENCES [species]([id])\n"
")"
)
assert fresh_db["individuals"].schema == (
'CREATE TABLE "individuals" (\n'
" [id] INTEGER PRIMARY KEY,\n"
" [name] TEXT,\n"
" [species_id] INTEGER REFERENCES [species]([id])\n"
")"
)
assert list(fresh_db["species"].rows) == [
{"id": 1, "name": "Wolf"},
{"id": 2, "name": "Fox"},
]
def test_extract_error_on_incompatible_existing_lookup_table(fresh_db):
fresh_db["species"].insert({"id": 1})
fresh_db["tree"].insert({"name": "Tree 1", "common_name": "Palm"})
with pytest.raises(InvalidColumns):
fresh_db["tree"].extract("common_name", table="species")
# Try again with incompatible existing column type
fresh_db["species2"].insert({"id": 1, "common_name": 3.5})
with pytest.raises(InvalidColumns):
fresh_db["tree"].extract("common_name", table="species2")
def test_extract_works_with_null_values(fresh_db):
fresh_db["listens"].insert_all(
[
{"id": 1, "track_title": "foo", "album_title": "bar"},
{"id": 2, "track_title": "baz", "album_title": None},
],
pk="id",
)
fresh_db["listens"].extract(
columns=["album_title"], table="albums", fk_column="album_id"
)
assert list(fresh_db["listens"].rows) == [
{"id": 1, "track_title": "foo", "album_id": 1},
{"id": 2, "track_title": "baz", "album_id": 2},
]
assert list(fresh_db["albums"].rows) == [
{"id": 1, "album_title": "bar"},
{"id": 2, "album_title": None},
]
</document_content>
</document>
<document index="34">
<source>./tests/test_extracts.py</source>
<document_content>
from sqlite_utils.db import Index
import pytest
@pytest.mark.parametrize(
"kwargs,expected_table",
[
(dict(extracts={"species_id": "Species"}), "Species"),
(dict(extracts=["species_id"]), "species_id"),
(dict(extracts=("species_id",)), "species_id"),
],
)
@pytest.mark.parametrize("use_table_factory", [True, False])
def test_extracts(fresh_db, kwargs, expected_table, use_table_factory):
table_kwargs = {}
insert_kwargs = {}
if use_table_factory:
table_kwargs = kwargs
else:
insert_kwargs = kwargs
trees = fresh_db.table("Trees", **table_kwargs)
trees.insert_all(
[
{"id": 1, "species_id": "Oak"},
{"id": 2, "species_id": "Oak"},
{"id": 3, "species_id": "Palm"},
],
**insert_kwargs
)
# Should now have two tables: Trees and Species
assert {expected_table, "Trees"} == set(fresh_db.table_names())
assert (
"CREATE TABLE [{}] (\n [id] INTEGER PRIMARY KEY,\n [value] TEXT\n)".format(
expected_table
)
== fresh_db[expected_table].schema
)
assert (
"CREATE TABLE [Trees] (\n [id] INTEGER,\n [species_id] INTEGER REFERENCES [{}]([id])\n)".format(
expected_table
)
== fresh_db["Trees"].schema
)
# Should have a foreign key reference
assert len(fresh_db["Trees"].foreign_keys) == 1
fk = fresh_db["Trees"].foreign_keys[0]
assert fk.table == "Trees"
assert fk.column == "species_id"
# Should have unique index on Species
assert [
Index(
seq=0,
name="idx_{}_value".format(expected_table),
unique=1,
origin="c",
partial=0,
columns=["value"],
)
] == fresh_db[expected_table].indexes
# Finally, check the rows
assert [{"id": 1, "value": "Oak"}, {"id": 2, "value": "Palm"}] == list(
fresh_db[expected_table].rows
)
assert [
{"id": 1, "species_id": 1},
{"id": 2, "species_id": 1},
{"id": 3, "species_id": 2},
] == list(fresh_db["Trees"].rows)
</document_content>
</document>
<document index="35">
<source>./tests/test_fts.py</source>
<document_content>
import pytest
from sqlite_utils import Database
from sqlite_utils.utils import sqlite3
from unittest.mock import ANY
search_records = [
{
"text": "tanuki are running tricksters",
"country": "Japan",
"not_searchable": "foo",
},
{
"text": "racoons are biting trash pandas",
"country": "USA",
"not_searchable": "bar",
},
]
def test_enable_fts(fresh_db):
table = fresh_db["searchable"]
table.insert_all(search_records)
assert ["searchable"] == fresh_db.table_names()
table.enable_fts(["text", "country"], fts_version="FTS4")
assert [
"searchable",
"searchable_fts",
"searchable_fts_segments",
"searchable_fts_segdir",
"searchable_fts_docsize",
"searchable_fts_stat",
] == fresh_db.table_names()
assert [
{
"rowid": 1,
"text": "tanuki are running tricksters",
"country": "Japan",
"not_searchable": "foo",
}
] == list(table.search("tanuki"))
assert [
{
"rowid": 2,
"text": "racoons are biting trash pandas",
"country": "USA",
"not_searchable": "bar",
}
] == list(table.search("usa"))
assert [] == list(table.search("bar"))
def test_enable_fts_escape_table_names(fresh_db):
# Table names with restricted chars are handled correctly.
# colons and dots are restricted characters for table names.
table = fresh_db["http://example.com"]
table.insert_all(search_records)
assert ["http://example.com"] == fresh_db.table_names()
table.enable_fts(["text", "country"], fts_version="FTS4")
assert [
"http://example.com",
"http://example.com_fts",
"http://example.com_fts_segments",
"http://example.com_fts_segdir",
"http://example.com_fts_docsize",
"http://example.com_fts_stat",
] == fresh_db.table_names()
assert [
{
"rowid": 1,
"text": "tanuki are running tricksters",
"country": "Japan",
"not_searchable": "foo",
}
] == list(table.search("tanuki"))
assert [
{
"rowid": 2,
"text": "racoons are biting trash pandas",
"country": "USA",
"not_searchable": "bar",
}
] == list(table.search("usa"))
assert [] == list(table.search("bar"))
def test_search_limit_offset(fresh_db):
table = fresh_db["t"]
table.insert_all(search_records)
table.enable_fts(["text", "country"], fts_version="FTS4")
assert len(list(table.search("are"))) == 2
assert len(list(table.search("are", limit=1))) == 1
assert list(table.search("are", limit=1, order_by="rowid"))[0]["rowid"] == 1
assert (
list(table.search("are", limit=1, offset=1, order_by="rowid"))[0]["rowid"] == 2
)
@pytest.mark.parametrize("fts_version", ("FTS4", "FTS5"))
def test_search_where(fresh_db, fts_version):
table = fresh_db["t"]
table.insert_all(search_records)
table.enable_fts(["text", "country"], fts_version=fts_version)
results = list(
table.search("are", where="country = :country", where_args={"country": "Japan"})
)
assert results == [
{
"rowid": 1,
"text": "tanuki are running tricksters",
"country": "Japan",
"not_searchable": "foo",
}
]
def test_search_where_args_disallows_query(fresh_db):
table = fresh_db["t"]
with pytest.raises(ValueError) as ex:
list(
table.search(
"x", where="author = :query", where_args={"query": "not allowed"}
)
)
assert (
ex.value.args[0]
== "'query' is a reserved key and cannot be passed to where_args for .search()"
)
def test_search_include_rank(fresh_db):
table = fresh_db["t"]
table.insert_all(search_records)
table.enable_fts(["text", "country"], fts_version="FTS5")
results = list(table.search("are", include_rank=True))
assert results == [
{
"rowid": 1,
"text": "tanuki are running tricksters",
"country": "Japan",
"not_searchable": "foo",
"rank": ANY,
},
{
"rowid": 2,
"text": "racoons are biting trash pandas",
"country": "USA",
"not_searchable": "bar",
"rank": ANY,
},
]
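# FTS5 rank is a bm25-based score: values are negative and better matches sort lower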
assert isinstance(results[0]["rank"], float)
assert isinstance(results[1]["rank"], float)
assert results[0]["rank"] < results[1]["rank"]
def test_enable_fts_table_names_containing_spaces(fresh_db):
table = fresh_db["test"]
table.insert({"column with spaces": "in its name"})
table.enable_fts(["column with spaces"])
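# No fts_version given: FTS5 is used here, creating _data, _idx, _docsize and _config shadow tables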
assert [
"test",
"test_fts",
"test_fts_data",
"test_fts_idx",
"test_fts_docsize",
"test_fts_config",
] == fresh_db.table_names()
def test_populate_fts(fresh_db):
table = fresh_db["populatable"]
table.insert(search_records[0])
table.enable_fts(["text", "country"], fts_version="FTS4")
assert [] == list(table.search("trash pandas"))
table.insert(search_records[1])
assert [] == list(table.search("trash pandas"))
# Now run populate_fts to make this record available
table.populate_fts(["text", "country"])
rows = list(table.search("usa"))
assert [
{
"rowid": 2,
"text": "racoons are biting trash pandas",
"country": "USA",
"not_searchable": "bar",
}
] == rows
def test_populate_fts_escape_table_names(fresh_db):
# Restricted characters such as colons and dots should be escaped.
table = fresh_db["http://example.com"]
table.insert(search_records[0])
table.enable_fts(["text", "country"], fts_version="FTS4")
assert [] == list(table.search("trash pandas"))
table.insert(search_records[1])
assert [] == list(table.search("trash pandas"))
# Now run populate_fts to make this record available
table.populate_fts(["text", "country"])
assert [
{
"rowid": 2,
"text": "racoons are biting trash pandas",
"country": "USA",
"not_searchable": "bar",
}
] == list(table.search("usa"))
@pytest.mark.parametrize("fts_version", ("4", "5"))
def test_fts_tokenize(fresh_db, fts_version):
table_name = "searchable_{}".format(fts_version)
table = fresh_db[table_name]
table.insert_all(search_records)
# Test without porter stemming
table.enable_fts(
["text", "country"],
fts_version="FTS{}".format(fts_version),
)
assert [] == list(table.search("bite"))
# Test WITH stemming
table.disable_fts()
table.enable_fts(
["text", "country"],
fts_version="FTS{}".format(fts_version),
tokenize="porter",
)
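# With the porter tokenizer, searching for 'bite' now matches the stemmed 'biting'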
rows = list(table.search("bite", order_by="rowid"))
assert len(rows) == 1
assert {
"rowid": 2,
"text": "racoons are biting trash pandas",
"country": "USA",
"not_searchable": "bar",
}.items() <= rows[0].items()
def test_optimize_fts(fresh_db):
for fts_version in ("4", "5"):
table_name = "searchable_{}".format(fts_version)
table = fresh_db[table_name]
table.insert_all(search_records)
table.enable_fts(["text", "country"], fts_version="FTS{}".format(fts_version))
# You can call optimize successfully against the tables OR their _fts equivalents:
for table_name in (
"searchable_4",
"searchable_5",
"searchable_4_fts",
"searchable_5_fts",
):
fresh_db[table_name].optimize()
def test_enable_fts_with_triggers(fresh_db):
table = fresh_db["searchable"]
table.insert(search_records[0])
table.enable_fts(["text", "country"], fts_version="FTS4", create_triggers=True)
rows1 = list(table.search("tanuki"))
assert len(rows1) == 1
assert rows1 == [
{
"rowid": 1,
"text": "tanuki are running tricksters",
"country": "Japan",
"not_searchable": "foo",
}
]
table.insert(search_records[1])
# Triggers will auto-populate the FTS virtual table, so there is no need to call populate_fts()
rows2 = list(table.search("usa"))
assert rows2 == [
{
"rowid": 2,
"text": "racoons are biting trash pandas",
"country": "USA",
"not_searchable": "bar",
}
]
assert [] == list(table.search("bar"))
@pytest.mark.parametrize("create_triggers", [True, False])
def test_disable_fts(fresh_db, create_triggers):
table = fresh_db["searchable"]
table.insert(search_records[0])
table.enable_fts(["text", "country"], create_triggers=create_triggers)
assert {
"searchable",
"searchable_fts",
"searchable_fts_data",
"searchable_fts_idx",
"searchable_fts_docsize",
"searchable_fts_config",
} == set(fresh_db.table_names())
if create_triggers:
expected_triggers = {"searchable_ai", "searchable_ad", "searchable_au"}
else:
expected_triggers = set()
assert expected_triggers == set(
r[0]
for r in fresh_db.execute(
"select name from sqlite_master where type = 'trigger'"
).fetchall()
)
# Now run .disable_fts() and confirm it worked
table.disable_fts()
assert (
0
== fresh_db.execute(
"select count(*) from sqlite_master where type = 'trigger'"
).fetchone()[0]
)
assert ["searchable"] == fresh_db.table_names()
def test_rebuild_fts(fresh_db):
table = fresh_db["searchable"]
table.insert(search_records[0])
table.enable_fts(["text", "country"])
# Run a search
rows = list(table.search("are"))
assert len(rows) == 1
assert {
"rowid": 1,
"text": "tanuki are running tricksters",
"country": "Japan",
"not_searchable": "foo",
}.items() <= rows[0].items()
# Insert another record
table.insert(search_records[1])
# This should NOT show up in searches
assert len(list(table.search("are"))) == 1
# Running rebuild_fts() should fix it
table.rebuild_fts()
rows2 = list(table.search("are"))
assert len(rows2) == 2
@pytest.mark.parametrize("invalid_table", ["does_not_exist", "not_searchable"])
def test_rebuild_fts_invalid(fresh_db, invalid_table):
fresh_db["not_searchable"].insert({"foo": "bar"})
# Raise OperationalError on invalid table
with pytest.raises(sqlite3.OperationalError):
fresh_db[invalid_table].rebuild_fts()
@pytest.mark.parametrize("fts_version", ["FTS4", "FTS5"])
def test_rebuild_removes_junk_docsize_rows(tmpdir, fts_version):
# Recreating https://github.com/simonw/sqlite-utils/issues/149
path = tmpdir / "test.db"
db = Database(str(path), recursive_triggers=False)
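# recursive_triggers=False reproduces the bug: the delete implied by REPLACE does not fire the FTS delete trigger, so stale rows pile up in _docsize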
licenses = [{"key": "apache2", "name": "Apache 2"}, {"key": "bsd", "name": "BSD"}]
db["licenses"].insert_all(licenses, pk="key", replace=True)
db["licenses"].enable_fts(["name"], create_triggers=True, fts_version=fts_version)
assert db["licenses_fts_docsize"].count == 2
# Bug: insert with replace increases the number of rows in _docsize:
db["licenses"].insert_all(licenses, pk="key", replace=True)
assert db["licenses_fts_docsize"].count == 4
# rebuild should fix this:
db["licenses_fts"].rebuild_fts()
assert db["licenses_fts_docsize"].count == 2
@pytest.mark.parametrize(
"kwargs",
[
{"columns": ["title"]},
{"fts_version": "FTS4"},
{"create_triggers": True},
{"tokenize": "porter"},
],
)
def test_enable_fts_replace(kwargs):
db = Database(memory=True)
db["books"].insert(
{
"id": 1,
"title": "Habits of Australian Marsupials",
"author": "Marlee Hawkins",
},
pk="id",
)
db["books"].enable_fts(["title", "author"])
assert not db["books"].triggers
assert db["books_fts"].columns_dict.keys() == {"title", "author"}
assert "FTS5" in db["books_fts"].schema
assert "porter" not in db["books_fts"].schema
# Now modify the FTS configuration
should_have_changed_columns = "columns" in kwargs
if "columns" not in kwargs:
kwargs["columns"] = ["title", "author"]
db["books"].enable_fts(**kwargs, replace=True)
# Check that the new configuration is correct
if should_have_changed_columns:
assert db["books_fts"].columns_dict.keys() == set(["title"])
if "create_triggers" in kwargs:
assert db["books"].triggers
if "fts_version" in kwargs:
assert "FTS4" in db["books_fts"].schema
if "tokenize" in kwargs:
assert "porter" in db["books_fts"].schema
def test_enable_fts_replace_does_nothing_if_args_the_same():
queries = []
db = Database(memory=True, tracer=lambda sql, params: queries.append((sql, params)))
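# The tracer callback records every SQL statement and its parameters for inspection below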
db["books"].insert(
{
"id": 1,
"title": "Habits of Australian Marsupials",
"author": "Marlee Hawkins",
},
pk="id",
)
db["books"].enable_fts(["title", "author"], create_triggers=True)
queries.clear()
# Running that again shouldn't run much SQL:
db["books"].enable_fts(["title", "author"], create_triggers=True, replace=True)
# The only SQL that executed should be select statements
assert all(q[0].startswith("select ") for q in queries)
def test_enable_fts_error_message_on_views():
db = Database(memory=True)
db.create_view("hello", "select 1 + 1")
with pytest.raises(NotImplementedError) as e:
db["hello"].enable_fts()
assert e.value.args[0] == "enable_fts() is supported on tables but not on views"
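# search_sql() builds different SQL depending on the FTS version: FTS5 orders by the
# built-in rank column, while FTS4 falls back to rank_bm25(matchinfo(...))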
@pytest.mark.parametrize(
"kwargs,fts,expected",
[
(
{},
"FTS5",
(
"with original as (\n"
" select\n"
" rowid,\n"
" *\n"
" from [books]\n"
")\n"
"select\n"
" [original].*\n"
"from\n"
" [original]\n"
" join [books_fts] on [original].rowid = [books_fts].rowid\n"
"where\n"
" [books_fts] match :query\n"
"order by\n"
" [books_fts].rank"
),
),
(
{"columns": ["title"], "order_by": "rowid", "limit": 10},
"FTS5",
(
"with original as (\n"
" select\n"
" rowid,\n"
" [title]\n"
" from [books]\n"
")\n"
"select\n"
" [original].[title]\n"
"from\n"
" [original]\n"
" join [books_fts] on [original].rowid = [books_fts].rowid\n"
"where\n"
" [books_fts] match :query\n"
"order by\n"
" rowid\n"
"limit 10"
),
),
(
{"where": "author = :author"},
"FTS5",
(
"with original as (\n"
" select\n"
" rowid,\n"
" *\n"
" from [books]\n"
" where author = :author\n"
")\n"
"select\n"
" [original].*\n"
"from\n"
" [original]\n"
" join [books_fts] on [original].rowid = [books_fts].rowid\n"
"where\n"
" [books_fts] match :query\n"
"order by\n"
" [books_fts].rank"
),
),
(
{"columns": ["title"]},
"FTS4",
(
"with original as (\n"
" select\n"
" rowid,\n"
" [title]\n"
" from [books]\n"
")\n"
"select\n"
" [original].[title]\n"
"from\n"
" [original]\n"
" join [books_fts] on [original].rowid = [books_fts].rowid\n"
"where\n"
" [books_fts] match :query\n"
"order by\n"
" rank_bm25(matchinfo([books_fts], 'pcnalx'))"
),
),
(
{"offset": 1, "limit": 1},
"FTS4",
(
"with original as (\n"
" select\n"
" rowid,\n"
" *\n"
" from [books]\n"
")\n"
"select\n"
" [original].*\n"
"from\n"
" [original]\n"
" join [books_fts] on [original].rowid = [books_fts].rowid\n"
"where\n"
" [books_fts] match :query\n"
"order by\n"
" rank_bm25(matchinfo([books_fts], 'pcnalx'))\n"
"limit 1 offset 1"
),
),
(
{"limit": 2},
"FTS4",
(
"with original as (\n"
" select\n"
" rowid,\n"
" *\n"
" from [books]\n"
")\n"
"select\n"
" [original].*\n"
"from\n"
" [original]\n"
" join [books_fts] on [original].rowid = [books_fts].rowid\n"
"where\n"
" [books_fts] match :query\n"
"order by\n"
" rank_bm25(matchinfo([books_fts], 'pcnalx'))\n"
"limit 2"
),
),
(
{"where": "author = :author"},
"FTS4",
(
"with original as (\n"
" select\n"
" rowid,\n"
" *\n"
" from [books]\n"
" where author = :author\n"
")\n"
"select\n"
" [original].*\n"
"from\n"
" [original]\n"
" join [books_fts] on [original].rowid = [books_fts].rowid\n"
"where\n"
" [books_fts] match :query\n"
"order by\n"
" rank_bm25(matchinfo([books_fts], 'pcnalx'))"
),
),
(
{"include_rank": True},
"FTS5",
(
"with original as (\n"
" select\n"
" rowid,\n"
" *\n"
" from [books]\n"
")\n"
"select\n"
" [original].*,\n"
" [books_fts].rank rank\n"
"from\n"
" [original]\n"
" join [books_fts] on [original].rowid = [books_fts].rowid\n"
"where\n"
" [books_fts] match :query\n"
"order by\n"
" [books_fts].rank"
),
),
(
{"include_rank": True},
"FTS4",
(
"with original as (\n"
" select\n"
" rowid,\n"
" *\n"
" from [books]\n"
")\n"
"select\n"
" [original].*,\n"
" rank_bm25(matchinfo([books_fts], 'pcnalx')) rank\n"
"from\n"
" [original]\n"
" join [books_fts] on [original].rowid = [books_fts].rowid\n"
"where\n"
" [books_fts] match :query\n"
"order by\n"
" rank_bm25(matchinfo([books_fts], 'pcnalx'))"
),
),
],
)
def test_search_sql(kwargs, fts, expected):
db = Database(memory=True)
db["books"].insert(
{
"title": "Habits of Australian Marsupials",
"author": "Marlee Hawkins",
}
)
db["books"].enable_fts(["title", "author"], fts_version=fts)
sql = db["books"].search_sql(**kwargs)
assert sql == expected
@pytest.mark.parametrize(
"input,expected",
(
("dog", '"dog"'),
("cat,", '"cat,"'),
("cat's", '"cat\'s"'),
("dog.", '"dog."'),
("cat dog", '"cat" "dog"'),
# If a phrase is already double quoted, leave it as-is
('"cat dog"', '"cat dog"'),
('"cat dog" fish', '"cat dog" "fish"'),
# Sensibly handle unbalanced double quotes
('cat"', '"cat"'),
('"cat dog" "fish', '"cat dog" "fish"'),
),
)
def test_quote_fts_query(fresh_db, input, expected):
table = fresh_db["searchable"]
table.insert_all(search_records)
table.enable_fts(["text", "country"])
quoted = fresh_db.quote_fts(input)
assert quoted == expected
# Executing query does not crash.
list(table.search(quoted))
def test_search_quote(fresh_db):
table = fresh_db["searchable"]
table.insert_all(search_records)
table.enable_fts(["text", "country"])
query = "cat's"
with pytest.raises(sqlite3.OperationalError):
list(table.search(query))
# No exception with quote=True
list(table.search(query, quote=True))
</document_content>
</document>
<document index="36">
<source>./tests/test_get.py</source>
<document_content>
import pytest
from sqlite_utils.db import NotFoundError
def test_get_rowid(fresh_db):
dogs = fresh_db["dogs"]
cleo = {"name": "Cleo", "age": 4}
row_id = dogs.insert(cleo).last_rowid
assert cleo == dogs.get(row_id)
def test_get_primary_key(fresh_db):
dogs = fresh_db["dogs"]
cleo = {"name": "Cleo", "age": 4, "id": 5}
last_pk = dogs.insert(cleo, pk="id").last_pk
assert 5 == last_pk
assert cleo == dogs.get(5)
@pytest.mark.parametrize(
"argument,expected_msg",
[(100, None), (None, None), ((1, 2), "Need 1 primary key value"), ("2", None)],
)
def test_get_not_found(argument, expected_msg, fresh_db):
fresh_db["dogs"].insert(
{"id": 1, "name": "Cleo", "age": 4, "is_good": True}, pk="id"
)
with pytest.raises(NotFoundError) as excinfo:
fresh_db["dogs"].get(argument)
if expected_msg is not None:
assert expected_msg == excinfo.value.args[0]
</document_content>
</document>
<document index="37">
<source>./tests/test_gis.py</source>
<document_content>
import json
import pytest
from click.testing import CliRunner
from sqlite_utils.cli import cli
from sqlite_utils.db import Database
from sqlite_utils.utils import find_spatialite, sqlite3
try:
import sqlean
except ImportError:
sqlean = None
pytestmark = [
pytest.mark.skipif(
not find_spatialite(), reason="Could not find SpatiaLite extension"
),
pytest.mark.skipif(
not hasattr(sqlite3.Connection, "enable_load_extension"),
reason="sqlite3.Connection missing enable_load_extension",
),
pytest.mark.skipif(
sqlean is not None, reason="sqlean.py is not compatible with SpatiaLite"
),
]
# python API tests
def test_find_spatialite():
spatialite = find_spatialite()
assert spatialite is None or isinstance(spatialite, str)
def test_init_spatialite():
db = Database(memory=True)
spatialite = find_spatialite()
db.init_spatialite(spatialite)
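# init_spatialite() loads the extension and sets up SpatiaLite's metadata tables, including spatial_ref_sys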
assert "spatial_ref_sys" in db.table_names()
def test_add_geometry_column():
db = Database(memory=True)
spatialite = find_spatialite()
db.init_spatialite(spatialite)
# create a table first
table = db.create_table("locations", {"id": str, "properties": str})
table.add_geometry_column(
column_name="geometry",
geometry_type="Point",
srid=4326,
coord_dimension=2,
)
assert db["geometry_columns"].get(["locations", "geometry"]) == {
"f_table_name": "locations",
"f_geometry_column": "geometry",
"geometry_type": 1, # point
"coord_dimension": 2,
"srid": 4326,
"spatial_index_enabled": 0,
}
def test_create_spatial_index():
db = Database(memory=True)
spatialite = find_spatialite()
assert db.init_spatialite(spatialite)
# create a table, add a geometry column with default values
table = db.create_table("locations", {"id": str, "properties": str})
assert table.add_geometry_column("geometry", "Point")
# index it
assert table.create_spatial_index("geometry")
assert "idx_locations_geometry" in db.table_names()
def test_double_create_spatial_index():
db = Database(memory=True)
spatialite = find_spatialite()
db.init_spatialite(spatialite)
# create a table, add a geometry column with default values
table = db.create_table("locations", {"id": str, "properties": str})
table.add_geometry_column("geometry", "Point")
# index it, return True
assert table.create_spatial_index("geometry")
assert "idx_locations_geometry" in db.table_names()
# call it again, return False
assert not table.create_spatial_index("geometry")
# cli tests
@pytest.mark.parametrize("use_spatialite_shortcut", [True, False])
def test_query_load_extension(use_spatialite_shortcut):
# Without --load-extension:
result = CliRunner().invoke(cli, [":memory:", "select spatialite_version()"])
assert result.exit_code == 1
assert "no such function: spatialite_version" in result.output
# With --load-extension:
if use_spatialite_shortcut:
load_extension = "spatialite"
else:
load_extension = find_spatialite()
result = CliRunner().invoke(
cli,
[
":memory:",
"select spatialite_version()",
"--load-extension={}".format(load_extension),
],
)
assert result.exit_code == 0, result.stdout
assert ["spatialite_version()"] == list(json.loads(result.output)[0].keys())
def test_cli_create_spatialite(tmpdir):
# sqlite-utils create-database created.db --init-spatialite
db_path = tmpdir / "created.db"
result = CliRunner().invoke(
cli, ["create-database", str(db_path), "--init-spatialite"]
)
assert result.exit_code == 0
assert db_path.exists()
assert db_path.read_binary()[:16] == b"SQLite format 3\x00"
db = Database(str(db_path))
assert "spatial_ref_sys" in db.table_names()
def test_cli_add_geometry_column(tmpdir):
# create a rowid table with one column
db_path = tmpdir / "spatial.db"
db = Database(str(db_path))
db.init_spatialite()
table = db["locations"].create({"name": str})
result = CliRunner().invoke(
cli,
[
"add-geometry-column",
str(db_path),
table.name,
"geometry",
"--type",
"POINT",
],
)
assert result.exit_code == 0
assert db["geometry_columns"].get(["locations", "geometry"]) == {
"f_table_name": "locations",
"f_geometry_column": "geometry",
"geometry_type": 1, # point
"coord_dimension": 2,
"srid": 4326,
"spatial_index_enabled": 0,
}
def test_cli_add_geometry_column_options(tmpdir):
# create a rowid table with one column
db_path = tmpdir / "spatial.db"
db = Database(str(db_path))
db.init_spatialite()
table = db["locations"].create({"name": str})
result = CliRunner().invoke(
cli,
[
"add-geometry-column",
str(db_path),
table.name,
"geometry",
"-t",
"POLYGON",
"--srid",
"3857", # https://epsg.io/3857
"--not-null",
],
)
assert result.exit_code == 0
assert db["geometry_columns"].get(["locations", "geometry"]) == {
"f_table_name": "locations",
"f_geometry_column": "geometry",
"geometry_type": 3, # polygon
"coord_dimension": 2,
"srid": 3857,
"spatial_index_enabled": 0,
}
column = table.columns[1]
assert column.notnull
def test_cli_add_geometry_column_invalid_type(tmpdir):
# create a rowid table with one column
db_path = tmpdir / "spatial.db"
db = Database(str(db_path))
db.init_spatialite()
table = db["locations"].create({"name": str})
result = CliRunner().invoke(
cli,
[
"add-geometry-column",
str(db_path),
table.name,
"geometry",
"--type",
"NOT-A-TYPE",
],
)
assert 2 == result.exit_code
def test_cli_create_spatial_index(tmpdir):
# create a rowid table with one column
db_path = tmpdir / "spatial.db"
db = Database(str(db_path))
db.init_spatialite()
table = db["locations"].create({"name": str})
table.add_geometry_column("geometry", "POINT")
result = CliRunner().invoke(
cli, ["create-spatial-index", str(db_path), table.name, "geometry"]
)
assert result.exit_code == 0
assert "idx_locations_geometry" in db.table_names()
</document_content>
</document>
<document index="38">
<source>./tests/test_hypothesis.py</source>
<document_content>
from hypothesis import given
import hypothesis.strategies as st
import sqlite_utils
# SQLite integers are -(2^63) to 2^63 - 1
@given(st.integers(-9223372036854775808, 9223372036854775807))
def test_roundtrip_integers(integer):
db = sqlite_utils.Database(memory=True)
row = {
"integer": integer,
}
db["test"].insert(row)
assert list(db["test"].rows) == [row]
@given(st.text())
def test_roundtrip_text(text):
db = sqlite_utils.Database(memory=True)
row = {
"text": text,
}
db["test"].insert(row)
assert list(db["test"].rows) == [row]
@given(st.binary(max_size=1024 * 1024))
def test_roundtrip_binary(binary):
db = sqlite_utils.Database(memory=True)
row = {
"binary": binary,
}
db["test"].insert(row)
assert list(db["test"].rows) == [row]
@given(st.floats(allow_nan=False))
def test_roundtrip_floats(floats):
db = sqlite_utils.Database(memory=True)
row = {
"floats": floats,
}
db["test"].insert(row)
assert list(db["test"].rows) == [row]
</document_content>
</document>
<document index="39">
<source>./tests/test_insert_files.py</source>
<document_content>
from sqlite_utils import cli, Database
from click.testing import CliRunner
import os
import pathlib
import pytest
import sys
@pytest.mark.parametrize("silent", (False, True))
@pytest.mark.parametrize(
"pk_args,expected_pks",
(
(["--pk", "path"], ["path"]),
(["--pk", "path", "--pk", "name"], ["path", "name"]),
),
)
def test_insert_files(silent, pk_args, expected_pks):
runner = CliRunner()
with runner.isolated_filesystem():
tmpdir = pathlib.Path(".")
db_path = str(tmpdir / "files.db")
(tmpdir / "one.txt").write_text("This is file one", "utf-8")
(tmpdir / "two.txt").write_text("Two is shorter", "utf-8")
(tmpdir / "nested").mkdir()
(tmpdir / "nested" / "three.zz.txt").write_text("Three is nested", "utf-8")
coltypes = (
"name",
"path",
"fullpath",
"sha256",
"md5",
"mode",
"content",
"content_text",
"mtime",
"ctime",
"mtime_int",
"ctime_int",
"mtime_iso",
"ctime_iso",
"size",
"suffix",
"stem",
)
cols = []
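# Each -c name:attribute option asks insert-files to store that file attribute in the named column
# (here the column names match the attribute names)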
for coltype in coltypes:
cols += ["-c", "{}:{}".format(coltype, coltype)]
result = runner.invoke(
cli.cli,
["insert-files", db_path, "files", str(tmpdir)]
+ cols
+ pk_args
+ (["--silent"] if silent else []),
catch_exceptions=False,
)
assert result.exit_code == 0, result.stdout
db = Database(db_path)
rows_by_path = {r["path"]: r for r in db["files"].rows}
one, two, three = (
rows_by_path["one.txt"],
rows_by_path["two.txt"],
rows_by_path[os.path.join("nested", "three.zz.txt")],
)
assert {
"content": b"This is file one",
"content_text": "This is file one",
"md5": "556dfb57fce9ca301f914e2273adf354",
"name": "one.txt",
"path": "one.txt",
"sha256": "e34138f26b5f7368f298b4e736fea0aad87ddec69fbd04dc183b20f4d844bad5",
"size": 16,
"stem": "one",
"suffix": ".txt",
}.items() <= one.items()
assert {
"content": b"Two is shorter",
"content_text": "Two is shorter",
"md5": "f86f067b083af1911043eb215e74ac70",
"name": "two.txt",
"path": "two.txt",
"sha256": "9368988ed16d4a2da0af9db9b686d385b942cb3ffd4e013f43aed2ec041183d9",
"size": 14,
"stem": "two",
"suffix": ".txt",
}.items() <= two.items()
assert {
"content": b"Three is nested",
"content_text": "Three is nested",
"md5": "12580f341781f5a5b589164d3cd39523",
"name": "three.zz.txt",
"path": os.path.join("nested", "three.zz.txt"),
"sha256": "6dd45aaaaa6b9f96af19363a92c8fca5d34791d3c35c44eb19468a6a862cc8cd",
"size": 15,
"stem": "three.zz",
"suffix": ".txt",
}.items() <= three.items()
# Assert the other int/str/float columns exist and are of the right types
expected_types = {
"ctime": float,
"ctime_int": int,
"ctime_iso": str,
"mtime": float,
"mtime_int": int,
"mtime_iso": str,
"mode": int,
"fullpath": str,
"content": bytes,
"content_text": str,
"stem": str,
"suffix": str,
}
for colname, expected_type in expected_types.items():
for row in (one, two, three):
assert isinstance(row[colname], expected_type)
assert set(db["files"].pks) == set(expected_pks)
@pytest.mark.parametrize(
"use_text,encoding,input,expected",
(
(False, None, "hello world", b"hello world"),
(True, None, "hello world", "hello world"),
(False, None, b"S\xe3o Paulo", b"S\xe3o Paulo"),
(True, "latin-1", b"S\xe3o Paulo", "S\xe3o Paulo"),
),
)
def test_insert_files_stdin(use_text, encoding, input, expected):
runner = CliRunner()
with runner.isolated_filesystem():
tmpdir = pathlib.Path(".")
db_path = str(tmpdir / "files.db")
args = ["insert-files", db_path, "files", "-", "--name", "stdin-name"]
if use_text:
args += ["--text"]
if encoding is not None:
args += ["--encoding", encoding]
result = runner.invoke(
cli.cli,
args,
catch_exceptions=False,
input=input,
)
assert result.exit_code == 0, result.stdout
db = Database(db_path)
row = list(db["files"].rows)[0]
key = "content"
if use_text:
key = "content_text"
assert {"path": "stdin-name", key: expected}.items() <= row.items()
@pytest.mark.skipif(
sys.platform.startswith("win"),
reason="Windows has a different way of handling default encodings",
)
def test_insert_files_bad_text_encoding_error():
runner = CliRunner()
with runner.isolated_filesystem():
tmpdir = pathlib.Path(".")
latin = tmpdir / "latin.txt"
latin.write_bytes(b"S\xe3o Paulo")
db_path = str(tmpdir / "files.db")
result = runner.invoke(
cli.cli,
["insert-files", db_path, "files", str(latin), "--text"],
catch_exceptions=False,
)
assert result.exit_code == 1, result.output
assert result.output.strip().startswith(
"Error: Could not read file '{}' as text".format(str(latin.resolve()))
)
</document_content>
</document>
<document index="40">
<source>./tests/test_introspect.py</source>
<document_content>
from sqlite_utils.db import Index, View, Database, XIndex, XIndexColumn
import pytest
def test_table_names(existing_db):
assert ["foo"] == existing_db.table_names()
def test_view_names(fresh_db):
fresh_db.create_view("foo_view", "select 1")
assert ["foo_view"] == fresh_db.view_names()
def test_table_names_fts4(existing_db):
existing_db["woo"].insert({"title": "Hello"}).enable_fts(
["title"], fts_version="FTS4"
)
existing_db["woo2"].insert({"title": "Hello"}).enable_fts(
["title"], fts_version="FTS5"
)
assert ["woo_fts"] == existing_db.table_names(fts4=True)
assert ["woo2_fts"] == existing_db.table_names(fts5=True)
def test_detect_fts(existing_db):
existing_db["woo"].insert({"title": "Hello"}).enable_fts(
["title"], fts_version="FTS4"
)
existing_db["woo2"].insert({"title": "Hello"}).enable_fts(
["title"], fts_version="FTS5"
)
assert "woo_fts" == existing_db["woo"].detect_fts()
assert "woo_fts" == existing_db["woo_fts"].detect_fts()
assert "woo2_fts" == existing_db["woo2"].detect_fts()
assert "woo2_fts" == existing_db["woo2_fts"].detect_fts()
assert existing_db["foo"].detect_fts() is None
@pytest.mark.parametrize("reverse_order", (True, False))
def test_detect_fts_similar_tables(fresh_db, reverse_order):
# https://github.com/simonw/sqlite-utils/issues/434
table1, table2 = ("demo", "demo2")
if reverse_order:
table1, table2 = table2, table1
fresh_db[table1].insert({"title": "Hello"}).enable_fts(
["title"], fts_version="FTS4"
)
fresh_db[table2].insert({"title": "Hello"}).enable_fts(
["title"], fts_version="FTS4"
)
assert fresh_db[table1].detect_fts() == "{}_fts".format(table1)
assert fresh_db[table2].detect_fts() == "{}_fts".format(table2)
def test_tables(existing_db):
assert len(existing_db.tables) == 1
assert existing_db.tables[0].name == "foo"
def test_views(fresh_db):
fresh_db.create_view("foo_view", "select 1")
assert len(fresh_db.views) == 1
view = fresh_db.views[0]
assert isinstance(view, View)
assert view.name == "foo_view"
assert repr(view) == "<View foo_view (1)>"
assert view.columns_dict == {"1": str}
def test_count(existing_db):
assert existing_db["foo"].count == 3
assert existing_db["foo"].count_where() == 3
assert existing_db["foo"].execute_count() == 3
def test_count_where(existing_db):
assert existing_db["foo"].count_where("text != ?", ["two"]) == 2
assert existing_db["foo"].count_where("text != :t", {"t": "two"}) == 2
def test_columns(existing_db):
table = existing_db["foo"]
assert [{"name": "text", "type": "TEXT"}] == [
{"name": col.name, "type": col.type} for col in table.columns
]
def test_table_schema(existing_db):
assert existing_db["foo"].schema == "CREATE TABLE foo (text TEXT)"
def test_database_schema(existing_db):
assert existing_db.schema == "CREATE TABLE foo (text TEXT);"
def test_table_repr(fresh_db):
table = fresh_db["dogs"].insert({"name": "Cleo", "age": 4})
assert "<Table dogs (name, age)>" == repr(table)
assert "<Table cats (does not exist yet)>" == repr(fresh_db["cats"])
def test_indexes(fresh_db):
fresh_db.executescript(
"""
create table Gosh (c1 text, c2 text, c3 text);
create index Gosh_c1 on Gosh(c1);
create index Gosh_c2c3 on Gosh(c2, c3);
"""
)
assert [
Index(
seq=0,
name="Gosh_c2c3",
unique=0,
origin="c",
partial=0,
columns=["c2", "c3"],
),
Index(seq=1, name="Gosh_c1", unique=0, origin="c", partial=0, columns=["c1"]),
] == fresh_db["Gosh"].indexes
def test_xindexes(fresh_db):
fresh_db.executescript(
"""
create table Gosh (c1 text, c2 text, c3 text);
create index Gosh_c1 on Gosh(c1);
create index Gosh_c2c3 on Gosh(c2, c3 desc);
"""
)
assert fresh_db["Gosh"].xindexes == [
XIndex(
name="Gosh_c2c3",
columns=[
XIndexColumn(seqno=0, cid=1, name="c2", desc=0, coll="BINARY", key=1),
XIndexColumn(seqno=1, cid=2, name="c3", desc=1, coll="BINARY", key=1),
XIndexColumn(seqno=2, cid=-1, name=None, desc=0, coll="BINARY", key=0),
],
),
XIndex(
name="Gosh_c1",
columns=[
XIndexColumn(seqno=0, cid=0, name="c1", desc=0, coll="BINARY", key=1),
XIndexColumn(seqno=1, cid=-1, name=None, desc=0, coll="BINARY", key=0),
],
),
]
@pytest.mark.parametrize(
"column,expected_table_guess",
(
("author", "authors"),
("author_id", "authors"),
("authors", "authors"),
("genre", "genre"),
("genre_id", "genre"),
),
)
def test_guess_foreign_table(fresh_db, column, expected_table_guess):
fresh_db.create_table("authors", {"name": str})
fresh_db.create_table("genre", {"name": str})
assert expected_table_guess == fresh_db["books"].guess_foreign_table(column)
@pytest.mark.parametrize(
"pk,expected", ((None, ["rowid"]), ("id", ["id"]), (["id", "id2"], ["id", "id2"]))
)
def test_pks(fresh_db, pk, expected):
fresh_db["foo"].insert_all([{"id": 1, "id2": 2}], pk=pk)
assert expected == fresh_db["foo"].pks
def test_triggers_and_triggers_dict(fresh_db):
assert [] == fresh_db.triggers
authors = fresh_db["authors"]
authors.insert_all(
[
{"name": "Frank Herbert", "famous_works": "Dune"},
{"name": "Neal Stephenson", "famous_works": "Cryptonomicon"},
]
)
fresh_db["other"].insert({"foo": "bar"})
assert authors.triggers == []
assert authors.triggers_dict == {}
assert fresh_db["other"].triggers == []
assert fresh_db.triggers_dict == {}
authors.enable_fts(
["name", "famous_works"], fts_version="FTS4", create_triggers=True
)
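# create_triggers=True adds _ai, _ad and _au triggers (after insert/delete/update) that keep the FTS table in sync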
expected_triggers = {
("authors_ai", "authors"),
("authors_ad", "authors"),
("authors_au", "authors"),
}
assert expected_triggers == {(t.name, t.table) for t in fresh_db.triggers}
assert expected_triggers == {
(t.name, t.table) for t in fresh_db["authors"].triggers
}
expected_triggers = {
"authors_ai": (
"CREATE TRIGGER [authors_ai] AFTER INSERT ON [authors] BEGIN\n"
" INSERT INTO [authors_fts] (rowid, [name], [famous_works]) VALUES (new.rowid, new.[name], new.[famous_works]);\n"
"END"
),
"authors_ad": (
"CREATE TRIGGER [authors_ad] AFTER DELETE ON [authors] BEGIN\n"
" INSERT INTO [authors_fts] ([authors_fts], rowid, [name], [famous_works]) VALUES('delete', old.rowid, old.[name], old.[famous_works]);\n"
"END"
),
"authors_au": (
"CREATE TRIGGER [authors_au] AFTER UPDATE ON [authors] BEGIN\n"
" INSERT INTO [authors_fts] ([authors_fts], rowid, [name], [famous_works]) VALUES('delete', old.rowid, old.[name], old.[famous_works]);\n"
" INSERT INTO [authors_fts] (rowid, [name], [famous_works]) VALUES (new.rowid, new.[name], new.[famous_works]);\nEND"
),
}
assert authors.triggers_dict == expected_triggers
assert fresh_db["other"].triggers == []
assert fresh_db["other"].triggers_dict == {}
assert fresh_db.triggers_dict == expected_triggers
def test_has_counts_triggers(fresh_db):
authors = fresh_db["authors"]
authors.insert({"name": "Frank Herbert"})
assert not authors.has_counts_triggers
authors.enable_counts()
assert authors.has_counts_triggers
@pytest.mark.parametrize(
"sql,expected_name,expected_using",
[
(
"""
CREATE VIRTUAL TABLE foo USING FTS5(name)
""",
"foo",
"FTS5",
),
(
"""
CREATE VIRTUAL TABLE "foo" USING FTS4(name)
""",
"foo",
"FTS4",
),
(
"""
CREATE VIRTUAL TABLE IF NOT EXISTS `foo` USING FTS4(name)
""",
"foo",
"FTS4",
),
(
"""
CREATE VIRTUAL TABLE IF NOT EXISTS `foo` USING fts5(name)
""",
"foo",
"FTS5",
),
(
"""
CREATE TABLE IF NOT EXISTS `foo` (id integer primary key)
""",
"foo",
None,
),
],
)
def test_virtual_table_using(fresh_db, sql, expected_name, expected_using):
fresh_db.execute(sql)
assert fresh_db[expected_name].virtual_table_using == expected_using
def test_use_rowid(fresh_db):
fresh_db["rowid_table"].insert({"name": "Cleo"})
fresh_db["regular_table"].insert({"id": 1, "name": "Cleo"}, pk="id")
assert fresh_db["rowid_table"].use_rowid
assert not fresh_db["regular_table"].use_rowid
@pytest.mark.skipif(
not Database(memory=True).supports_strict,
reason="Needs SQLite version that supports strict",
)
@pytest.mark.parametrize(
"create_table,expected_strict",
(
("create table t (id integer) strict", True),
("create table t (id integer) STRICT", True),
("create table t (id integer primary key) StriCt, WITHOUT ROWID", True),
("create table t (id integer primary key) WITHOUT ROWID", False),
("create table t (id integer)", False),
),
)
def test_table_strict(fresh_db, create_table, expected_strict):
fresh_db.execute(create_table)
table = fresh_db["t"]
assert table.strict == expected_strict
@pytest.mark.parametrize(
"value",
(
1,
1.3,
"foo",
True,
b"binary",
),
)
def test_table_default_values(fresh_db, value):
fresh_db["default_values"].insert(
{"nodefault": 1, "value": value}, defaults={"value": value}
)
default_values = fresh_db["default_values"].default_values
assert default_values == {"value": value}
</document_content>
</document>
<document index="41">
<source>./tests/test_lookup.py</source>
<document_content>
from sqlite_utils.db import Index
import pytest
def test_lookup_new_table(fresh_db):
species = fresh_db["species"]
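# lookup() creates the table if needed, with a unique index on the lookup columns,
# and returns the primary key of the matching (or newly inserted) row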
palm_id = species.lookup({"name": "Palm"})
oak_id = species.lookup({"name": "Oak"})
cherry_id = species.lookup({"name": "Cherry"})
assert palm_id == species.lookup({"name": "Palm"})
assert oak_id == species.lookup({"name": "Oak"})
assert cherry_id == species.lookup({"name": "Cherry"})
assert palm_id != oak_id != cherry_id
# Ensure the correct indexes were created
assert [
Index(
seq=0,
name="idx_species_name",
unique=1,
origin="c",
partial=0,
columns=["name"],
)
] == species.indexes
def test_lookup_new_table_compound_key(fresh_db):
species = fresh_db["species"]
palm_id = species.lookup({"name": "Palm", "type": "Tree"})
oak_id = species.lookup({"name": "Oak", "type": "Tree"})
assert palm_id == species.lookup({"name": "Palm", "type": "Tree"})
assert oak_id == species.lookup({"name": "Oak", "type": "Tree"})
assert [
Index(
seq=0,
name="idx_species_name_type",
unique=1,
origin="c",
partial=0,
columns=["name", "type"],
)
] == species.indexes
def test_lookup_adds_unique_constraint_to_existing_table(fresh_db):
species = fresh_db.table("species", pk="id")
palm_id = species.insert({"name": "Palm"}).last_pk
species.insert({"name": "Oak"})
assert [] == species.indexes
assert palm_id == species.lookup({"name": "Palm"})
assert [
Index(
seq=0,
name="idx_species_name",
unique=1,
origin="c",
partial=0,
columns=["name"],
)
] == species.indexes
def test_lookup_fails_if_constraint_cannot_be_added(fresh_db):
species = fresh_db.table("species", pk="id")
species.insert_all([{"id": 1, "name": "Palm"}, {"id": 2, "name": "Palm"}])
# This will fail because the name column is not unique
with pytest.raises(Exception, match="UNIQUE constraint failed"):
species.lookup({"name": "Palm"})
def test_lookup_with_extra_values(fresh_db):
species = fresh_db["species"]
id = species.lookup({"name": "Palm", "type": "Tree"}, {"first_seen": "2020-01-01"})
assert species.get(id) == {
"id": 1,
"name": "Palm",
"type": "Tree",
"first_seen": "2020-01-01",
}
# A subsequent lookup() should ignore the second dictionary
id2 = species.lookup({"name": "Palm", "type": "Tree"}, {"first_seen": "2021-02-02"})
assert id2 == id
assert species.get(id2) == {
"id": 1,
"name": "Palm",
"type": "Tree",
"first_seen": "2020-01-01",
}
def test_lookup_with_extra_insert_parameters(fresh_db):
other_table = fresh_db["other_table"]
other_table.insert({"id": 1, "name": "Name"}, pk="id")
species = fresh_db["species"]
id = species.lookup(
{"name": "Palm", "type": "Tree"},
{
"first_seen": "2020-01-01",
"make_not_null": 1,
"fk_to_other": 1,
"default_is_dog": "cat",
"extract_this": "This is extracted",
"convert_to_upper": "upper",
"make_this_integer": "2",
"this_at_front": 1,
},
pk="renamed_id",
foreign_keys=(("fk_to_other", "other_table", "id"),),
column_order=("this_at_front",),
not_null={"make_not_null"},
defaults={"default_is_dog": "dog"},
extracts=["extract_this"],
conversions={"convert_to_upper": "upper(?)"},
columns={"make_this_integer": int},
)
assert species.schema == (
"CREATE TABLE [species] (\n"
" [renamed_id] INTEGER PRIMARY KEY,\n"
" [this_at_front] INTEGER,\n"
" [name] TEXT,\n"
" [type] TEXT,\n"
" [first_seen] TEXT,\n"
" [make_not_null] INTEGER NOT NULL,\n"
" [fk_to_other] INTEGER REFERENCES [other_table]([id]),\n"
" [default_is_dog] TEXT DEFAULT 'dog',\n"
" [extract_this] INTEGER REFERENCES [extract_this]([id]),\n"
" [convert_to_upper] TEXT,\n"
" [make_this_integer] INTEGER\n"
")"
)
assert species.get(id) == {
"renamed_id": id,
"this_at_front": 1,
"name": "Palm",
"type": "Tree",
"first_seen": "2020-01-01",
"make_not_null": 1,
"fk_to_other": 1,
"default_is_dog": "cat",
"extract_this": 1,
"convert_to_upper": "UPPER",
"make_this_integer": 2,
}
assert species.indexes == [
Index(
seq=0,
name="idx_species_name_type",
unique=1,
origin="c",
partial=0,
columns=["name", "type"],
)
]
@pytest.mark.parametrize("strict", (False, True))
def test_lookup_new_table_strict(fresh_db, strict):
fresh_db["species"].lookup({"name": "Palm"}, strict=strict)
assert fresh_db["species"].strict == strict or not fresh_db.supports_strict
</document_content>
</document>
<document index="42">
<source>./tests/test_m2m.py</source>
<document_content>
from sqlite_utils.db import ForeignKey, NoObviousTable
import pytest
def test_insert_m2m_single(fresh_db):
dogs = fresh_db["dogs"]
dogs.insert({"id": 1, "name": "Cleo"}, pk="id").m2m(
"humans", {"id": 1, "name": "Natalie D"}, pk="id"
)
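# By default the join table is named after the two tables sorted alphabetically (dogs_humans)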
assert {"dogs_humans", "humans", "dogs"} == set(fresh_db.table_names())
humans = fresh_db["humans"]
dogs_humans = fresh_db["dogs_humans"]
assert [{"id": 1, "name": "Natalie D"}] == list(humans.rows)
assert [{"humans_id": 1, "dogs_id": 1}] == list(dogs_humans.rows)
def test_insert_m2m_alter(fresh_db):
dogs = fresh_db["dogs"]
dogs.insert({"id": 1, "name": "Cleo"}, pk="id").m2m(
"humans", {"id": 1, "name": "Natalie D"}, pk="id"
)
dogs.update(1).m2m(
"humans", {"id": 2, "name": "Simon W", "nerd": True}, pk="id", alter=True
)
assert list(fresh_db["humans"].rows) == [
{"id": 1, "name": "Natalie D", "nerd": None},
{"id": 2, "name": "Simon W", "nerd": 1},
]
assert list(fresh_db["dogs_humans"].rows) == [
{"humans_id": 1, "dogs_id": 1},
{"humans_id": 2, "dogs_id": 1},
]
def test_insert_m2m_list(fresh_db):
dogs = fresh_db["dogs"]
dogs.insert({"id": 1, "name": "Cleo"}, pk="id").m2m(
"humans",
[{"id": 1, "name": "Natalie D"}, {"id": 2, "name": "Simon W"}],
pk="id",
)
assert {"dogs", "humans", "dogs_humans"} == set(fresh_db.table_names())
humans = fresh_db["humans"]
dogs_humans = fresh_db["dogs_humans"]
assert [{"humans_id": 1, "dogs_id": 1}, {"humans_id": 2, "dogs_id": 1}] == list(
dogs_humans.rows
)
assert [{"id": 1, "name": "Natalie D"}, {"id": 2, "name": "Simon W"}] == list(
humans.rows
)
assert [
ForeignKey(
table="dogs_humans", column="dogs_id", other_table="dogs", other_column="id"
),
ForeignKey(
table="dogs_humans",
column="humans_id",
other_table="humans",
other_column="id",
),
] == dogs_humans.foreign_keys
def test_insert_m2m_iterable(fresh_db):
iterable_records = ({"id": 1, "name": "Phineas"}, {"id": 2, "name": "Ferb"})
def iterable():
for record in iterable_records:
yield record
platypuses = fresh_db["platypuses"]
platypuses.insert({"id": 1, "name": "Perry"}, pk="id").m2m(
"humans",
iterable(),
pk="id",
)
assert {"platypuses", "humans", "humans_platypuses"} == set(fresh_db.table_names())
humans = fresh_db["humans"]
humans_platypuses = fresh_db["humans_platypuses"]
assert [
{"humans_id": 1, "platypuses_id": 1},
{"humans_id": 2, "platypuses_id": 1},
] == list(humans_platypuses.rows)
assert [{"id": 1, "name": "Phineas"}, {"id": 2, "name": "Ferb"}] == list(
humans.rows
)
assert [
ForeignKey(
table="humans_platypuses",
column="platypuses_id",
other_table="platypuses",
other_column="id",
),
ForeignKey(
table="humans_platypuses",
column="humans_id",
other_table="humans",
other_column="id",
),
] == humans_platypuses.foreign_keys
def test_m2m_with_table_objects(fresh_db):
dogs = fresh_db.table("dogs", pk="id")
humans = fresh_db.table("humans", pk="id")
dogs.insert({"id": 1, "name": "Cleo"}).m2m(
humans, [{"id": 1, "name": "Natalie D"}, {"id": 2, "name": "Simon W"}]
)
expected_tables = {"dogs", "humans", "dogs_humans"}
assert expected_tables == set(fresh_db.table_names())
assert dogs.count == 1
assert humans.count == 2
assert fresh_db["dogs_humans"].count == 2
def test_m2m_lookup(fresh_db):
people = fresh_db.table("people", pk="id")
people.insert({"name": "Wahyu"}).m2m("tags", lookup={"tag": "Coworker"})
people_tags = fresh_db["people_tags"]
tags = fresh_db["tags"]
assert people_tags.exists()
assert tags.exists()
assert [
ForeignKey(
table="people_tags",
column="people_id",
other_table="people",
other_column="id",
),
ForeignKey(
table="people_tags", column="tags_id", other_table="tags", other_column="id"
),
] == people_tags.foreign_keys
assert [{"people_id": 1, "tags_id": 1}] == list(people_tags.rows)
assert [{"id": 1, "name": "Wahyu"}] == list(people.rows)
assert [{"id": 1, "tag": "Coworker"}] == list(tags.rows)
def test_m2m_requires_either_records_or_lookup(fresh_db):
people = fresh_db.table("people", pk="id").insert({"name": "Wahyu"})
with pytest.raises(AssertionError):
people.m2m("tags")
with pytest.raises(AssertionError):
people.m2m("tags", {"tag": "hello"}, lookup={"foo": "bar"})
def test_m2m_explicit_table_name_argument(fresh_db):
people = fresh_db.table("people", pk="id")
people.insert({"name": "Wahyu"}).m2m(
"tags", lookup={"tag": "Coworker"}, m2m_table="tagged"
)
assert fresh_db["tags"].exists
assert fresh_db["tagged"].exists
assert not fresh_db["people_tags"].exists()
def test_m2m_table_candidates(fresh_db):
fresh_db.create_table("one", {"id": int, "name": str}, pk="id")
fresh_db.create_table("two", {"id": int, "name": str}, pk="id")
fresh_db.create_table("three", {"id": int, "name": str}, pk="id")
# No candidates at first
assert [] == fresh_db.m2m_table_candidates("one", "two")
# Create a candidate
fresh_db.create_table(
"one_m2m_two", {"one_id": int, "two_id": int}, foreign_keys=["one_id", "two_id"]
)
assert ["one_m2m_two"] == fresh_db.m2m_table_candidates("one", "two")
# Add another table and there should be two candidates
fresh_db.create_table(
"one_m2m_two_and_three",
{"one_id": int, "two_id": int, "three_id": int},
foreign_keys=["one_id", "two_id", "three_id"],
)
assert {"one_m2m_two", "one_m2m_two_and_three"} == set(
fresh_db.m2m_table_candidates("one", "two")
)
def test_uses_existing_m2m_table_if_exists(fresh_db):
# Code should look for an existing table with fks to both tables
# and use that if it exists.
people = fresh_db.create_table("people", {"id": int, "name": str}, pk="id")
fresh_db["tags"].lookup({"tag": "Coworker"})
fresh_db.create_table(
"tagged",
{"people_id": int, "tags_id": int},
foreign_keys=["people_id", "tags_id"],
)
people.insert({"name": "Wahyu"}).m2m("tags", lookup={"tag": "Coworker"})
assert fresh_db["tags"].exists()
assert fresh_db["tagged"].exists()
assert not fresh_db["people_tags"].exists()
assert not fresh_db["tags_people"].exists()
assert [{"people_id": 1, "tags_id": 1}] == list(fresh_db["tagged"].rows)
def test_requires_explicit_m2m_table_if_multiple_options(fresh_db):
# If the code scans for m2m tables and finds more than one candidate
# it should require that the m2m_table=x argument is used
people = fresh_db.create_table("people", {"id": int, "name": str}, pk="id")
fresh_db["tags"].lookup({"tag": "Coworker"})
fresh_db.create_table(
"tagged",
{"people_id": int, "tags_id": int},
foreign_keys=["people_id", "tags_id"],
)
fresh_db.create_table(
"tagged2",
{"people_id": int, "tags_id": int},
foreign_keys=["people_id", "tags_id"],
)
with pytest.raises(NoObviousTable):
people.insert({"name": "Wahyu"}).m2m("tags", lookup={"tag": "Coworker"})
</document_content>
</document>
<document index="43">
<source>./tests/test_plugins.py</source>
<document_content>
from click.testing import CliRunner
import click
import importlib
import pytest
from sqlite_utils import cli, Database, hookimpl, plugins
def _supports_pragma_function_list():
db = Database(memory=True)
try:
db.execute("select * from pragma_function_list()")
except Exception:
return False
return True
def test_register_commands():
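# Reload the CLI module so register_commands hooks from currently registered plugins are (re-)applied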
importlib.reload(cli)
assert plugins.get_plugins() == []
class HelloWorldPlugin:
__name__ = "HelloWorldPlugin"
@hookimpl
def register_commands(self, cli):
@cli.command(name="hello-world")
def hello_world():
"Print hello world"
click.echo("Hello world!")
try:
plugins.pm.register(HelloWorldPlugin(), name="HelloWorldPlugin")
importlib.reload(cli)
assert plugins.get_plugins() == [
{"name": "HelloWorldPlugin", "hooks": ["register_commands"]}
]
runner = CliRunner()
result = runner.invoke(cli.cli, ["hello-world"])
assert result.exit_code == 0
assert result.output == "Hello world!\n"
finally:
plugins.pm.unregister(name="HelloWorldPlugin")
importlib.reload(cli)
assert plugins.get_plugins() == []
@pytest.mark.skipif(
not _supports_pragma_function_list(),
reason="Needs SQLite version that supports pragma_function_list()",
)
def test_prepare_connection():
importlib.reload(cli)
assert plugins.get_plugins() == []
class HelloFunctionPlugin:
__name__ = "HelloFunctionPlugin"
@hookimpl
def prepare_connection(self, conn):
conn.create_function("hello", 1, lambda name: f"Hello, {name}!")
db = Database(memory=True)
def _functions(db):
return [
row[0]
for row in db.execute(
"select distinct name from pragma_function_list() order by 1"
).fetchall()
]
assert "hello" not in _functions(db)
try:
plugins.pm.register(HelloFunctionPlugin(), name="HelloFunctionPlugin")
assert plugins.get_plugins() == [
{"name": "HelloFunctionPlugin", "hooks": ["prepare_connection"]}
]
db = Database(memory=True)
assert "hello" in _functions(db)
result = db.execute('select hello("world")').fetchone()[0]
assert result == "Hello, world!"
# Test execute_plugins=False
db2 = Database(memory=True, execute_plugins=False)
assert "hello" not in _functions(db2)
finally:
plugins.pm.unregister(name="HelloFunctionPlugin")
assert plugins.get_plugins() == []
</document_content>
</document>
<document index="44">
<source>./tests/test_query.py</source>
<document_content>
import types
def test_query(fresh_db):
fresh_db["dogs"].insert_all([{"name": "Cleo"}, {"name": "Pancakes"}])
results = fresh_db.query("select * from dogs order by name desc")
assert isinstance(results, types.GeneratorType)
assert list(results) == [{"name": "Pancakes"}, {"name": "Cleo"}]
def test_execute_returning_dicts(fresh_db):
# Like db.query() but returns a list, included for backwards compatibility
# see https://github.com/simonw/sqlite-utils/issues/290
fresh_db["test"].insert({"id": 1, "bar": 2}, pk="id")
assert fresh_db.execute_returning_dicts("select * from test") == [
{"id": 1, "bar": 2}
]
</document_content>
</document>
<document index="45">
<source>./tests/test_recipes.py</source>
<document_content>
from sqlite_utils import recipes
from sqlite_utils.utils import sqlite3
import json
import pytest
@pytest.fixture
def dates_db(fresh_db):
fresh_db["example"].insert_all(
[
{"id": 1, "dt": "5th October 2019 12:04"},
{"id": 2, "dt": "6th October 2019 00:05:06"},
{"id": 3, "dt": ""},
{"id": 4, "dt": None},
],
pk="id",
)
return fresh_db
def test_parsedate(dates_db):
dates_db["example"].convert("dt", recipes.parsedate)
assert list(dates_db["example"].rows) == [
{"id": 1, "dt": "2019-10-05"},
{"id": 2, "dt": "2019-10-06"},
{"id": 3, "dt": ""},
{"id": 4, "dt": None},
]
def test_parsedatetime(dates_db):
dates_db["example"].convert("dt", recipes.parsedatetime)
assert list(dates_db["example"].rows) == [
{"id": 1, "dt": "2019-10-05T12:04:00"},
{"id": 2, "dt": "2019-10-06T00:05:06"},
{"id": 3, "dt": ""},
{"id": 4, "dt": None},
]
@pytest.mark.parametrize(
"recipe,kwargs,expected",
(
("parsedate", {}, "2005-03-04"),
("parsedate", {"dayfirst": True}, "2005-04-03"),
("parsedatetime", {}, "2005-03-04T00:00:00"),
("parsedatetime", {"dayfirst": True}, "2005-04-03T00:00:00"),
),
)
def test_dayfirst_yearfirst(fresh_db, recipe, kwargs, expected):
fresh_db["example"].insert_all(
[
{"id": 1, "dt": "03/04/05"},
],
pk="id",
)
fresh_db["example"].convert(
"dt", lambda value: getattr(recipes, recipe)(value, **kwargs)
)
assert list(fresh_db["example"].rows) == [
{"id": 1, "dt": expected},
]
@pytest.mark.parametrize("fn", ("parsedate", "parsedatetime"))
@pytest.mark.parametrize("errors", (None, recipes.SET_NULL, recipes.IGNORE))
def test_dateparse_errors(fresh_db, fn, errors):
fresh_db["example"].insert_all(
[
{"id": 1, "dt": "invalid"},
],
pk="id",
)
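# With errors=None a parse failure propagates (convert() surfaces it as sqlite3.OperationalError);
# SET_NULL stores None instead and IGNORE keeps the original value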
if errors is None:
# Should raise an error
with pytest.raises(sqlite3.OperationalError):
fresh_db["example"].convert("dt", lambda value: getattr(recipes, fn)(value))
else:
fresh_db["example"].convert(
"dt", lambda value: getattr(recipes, fn)(value, errors=errors)
)
rows = list(fresh_db["example"].rows)
expected = [{"id": 1, "dt": None if errors is recipes.SET_NULL else "invalid"}]
assert rows == expected
@pytest.mark.parametrize("delimiter", [None, ";", "-"])
def test_jsonsplit(fresh_db, delimiter):
fresh_db["example"].insert_all(
[
{"id": 1, "tags": (delimiter or ",").join(["foo", "bar"])},
{"id": 2, "tags": (delimiter or ",").join(["bar", "baz"])},
],
pk="id",
)
if delimiter is not None:
def fn(value):
return recipes.jsonsplit(value, delimiter=delimiter)
else:
fn = recipes.jsonsplit
fresh_db["example"].convert("tags", fn)
assert list(fresh_db["example"].rows) == [
{"id": 1, "tags": '["foo", "bar"]'},
{"id": 2, "tags": '["bar", "baz"]'},
]
@pytest.mark.parametrize(
"type,expected",
(
(None, ["1", "2", "3"]),
(float, [1.0, 2.0, 3.0]),
(int, [1, 2, 3]),
),
)
def test_jsonsplit_type(fresh_db, type, expected):
fresh_db["example"].insert_all(
[
{"id": 1, "records": "1,2,3"},
],
pk="id",
)
if type is not None:
def fn(value):
return recipes.jsonsplit(value, type=type)
else:
fn = recipes.jsonsplit
fresh_db["example"].convert("records", fn)
assert json.loads(fresh_db["example"].get(1)["records"]) == expected
</document_content>
</document>
<document index="46">
<source>./tests/test_recreate.py</source>
<document_content>
from sqlite_utils import Database
import sqlite3
import pathlib
import pytest
def test_recreate_ignored_for_in_memory():
# None of these should raise an exception:
Database(memory=True, recreate=False)
Database(memory=True, recreate=True)
Database(":memory:", recreate=False)
Database(":memory:", recreate=True)
def test_recreate_not_allowed_for_connection():
conn = sqlite3.connect(":memory:")
with pytest.raises(AssertionError):
Database(conn, recreate=True)
@pytest.mark.parametrize(
"use_path,create_file_first",
[(True, True), (True, False), (False, True), (False, False)],
)
def test_recreate(tmp_path, use_path, create_file_first):
filepath = str(tmp_path / "data.db")
if use_path:
filepath = pathlib.Path(filepath)
if create_file_first:
db = Database(filepath)
db["t1"].insert({"foo": "bar"})
assert ["t1"] == db.table_names()
db.close()
Database(filepath, recreate=True)["t2"].insert({"foo": "bar"})
assert ["t2"] == Database(filepath).table_names()
</document_content>
</document>
<document index="47">
<source>./tests/test_register_function.py</source>
<document_content>
# flake8: noqa
import pytest
import sys
from unittest.mock import MagicMock, call
from sqlite_utils.utils import sqlite3
def test_register_function(fresh_db):
@fresh_db.register_function
def reverse_string(s):
return "".join(reversed(list(s)))
result = fresh_db.execute('select reverse_string("hello")').fetchone()[0]
assert result == "olleh"
def test_register_function_custom_name(fresh_db):
@fresh_db.register_function(name="revstr")
def reverse_string(s):
return "".join(reversed(list(s)))
result = fresh_db.execute('select revstr("hello")').fetchone()[0]
assert result == "olleh"
def test_register_function_multiple_arguments(fresh_db):
@fresh_db.register_function
def a_times_b_plus_c(a, b, c):
return a * b + c
result = fresh_db.execute("select a_times_b_plus_c(2, 3, 4)").fetchone()[0]
assert result == 10
def test_register_function_deterministic(fresh_db):
@fresh_db.register_function(deterministic=True)
def to_lower(s):
return s.lower()
result = fresh_db.execute("select to_lower('BOB')").fetchone()[0]
assert result == "bob"
def test_register_function_deterministic_tries_again_if_exception_raised(fresh_db):
fresh_db.conn = MagicMock()
fresh_db.conn.create_function = MagicMock()
@fresh_db.register_function(deterministic=True)
def to_lower_2(s):
return s.lower()
fresh_db.conn.create_function.assert_called_with(
"to_lower_2", 1, to_lower_2, deterministic=True
)
first = True
def side_effect(*args, **kwargs):
# Raise exception only first time this is called
nonlocal first
if first:
first = False
raise sqlite3.NotSupportedError()
# But if sqlite3.NotSupportedError is raised, it tries again
fresh_db.conn.create_function.reset_mock()
fresh_db.conn.create_function.side_effect = side_effect
@fresh_db.register_function(deterministic=True)
def to_lower_3(s):
return s.lower()
# Should have been called once with deterministic=True and once without
assert fresh_db.conn.create_function.call_args_list == [
call("to_lower_3", 1, to_lower_3, deterministic=True),
call("to_lower_3", 1, to_lower_3),
]
def test_register_function_replace(fresh_db):
@fresh_db.register_function()
def one():
return "one"
assert "one" == fresh_db.execute("select one()").fetchone()[0]
# This will silently fail to replace the function
@fresh_db.register_function()
def one(): # noqa
return "two"
assert "one" == fresh_db.execute("select one()").fetchone()[0]
# This will replace it
@fresh_db.register_function(replace=True)
def one(): # noqa
return "two"
assert "two" == fresh_db.execute("select one()").fetchone()[0]
</document_content>
</document>
<document index="48">
<source>./tests/test_rows.py</source>
<document_content>
import pytest
def test_rows(existing_db):
assert [{"text": "one"}, {"text": "two"}, {"text": "three"}] == list(
existing_db["foo"].rows
)
@pytest.mark.parametrize(
"where,where_args,expected_ids",
[
("name = ?", ["Pancakes"], {2}),
("age > ?", [3], {1}),
("age > :age", {"age": 3}, {1}),
("name is not null", [], {1, 2}),
("is_good = ?", [True], {1, 2}),
],
)
def test_rows_where(where, where_args, expected_ids, fresh_db):
table = fresh_db["dogs"]
table.insert_all(
[
{"id": 1, "name": "Cleo", "age": 4, "is_good": True},
{"id": 2, "name": "Pancakes", "age": 3, "is_good": True},
],
pk="id",
)
assert expected_ids == {
r["id"] for r in table.rows_where(where, where_args, select="id")
}
@pytest.mark.parametrize(
"where,order_by,expected_ids",
[
(None, None, [1, 2, 3]),
(None, "id desc", [3, 2, 1]),
(None, "age", [3, 2, 1]),
("id > 1", "age", [3, 2]),
],
)
def test_rows_where_order_by(where, order_by, expected_ids, fresh_db):
table = fresh_db["dogs"]
table.insert_all(
[
{"id": 1, "name": "Cleo", "age": 4},
{"id": 2, "name": "Pancakes", "age": 3},
{"id": 3, "name": "Bailey", "age": 2},
],
pk="id",
)
assert expected_ids == [r["id"] for r in table.rows_where(where, order_by=order_by)]
@pytest.mark.parametrize(
"offset,limit,expected",
[
(None, 3, [1, 2, 3]),
(0, 3, [1, 2, 3]),
(3, 3, [4, 5, 6]),
],
)
def test_rows_where_offset_limit(fresh_db, offset, limit, expected):
table = fresh_db["rows"]
table.insert_all([{"id": id} for id in range(1, 101)], pk="id")
assert table.count == 100
assert expected == [
r["id"] for r in table.rows_where(offset=offset, limit=limit, order_by="id")
]
def test_pks_and_rows_where_rowid(fresh_db):
table = fresh_db["rowid_table"]
table.insert_all({"number": i + 10} for i in range(3))
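# pks_and_rows_where() yields (primary_key, row) pairs; for rowid tables the key is the rowid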
pks_and_rows = list(table.pks_and_rows_where())
assert pks_and_rows == [
(1, {"rowid": 1, "number": 10}),
(2, {"rowid": 2, "number": 11}),
(3, {"rowid": 3, "number": 12}),
]
def test_pks_and_rows_where_simple_pk(fresh_db):
table = fresh_db["simple_pk_table"]
table.insert_all(({"id": i + 10} for i in range(3)), pk="id")
pks_and_rows = list(table.pks_and_rows_where())
assert pks_and_rows == [
(10, {"id": 10}),
(11, {"id": 11}),
(12, {"id": 12}),
]
def test_pks_and_rows_where_compound_pk(fresh_db):
table = fresh_db["compound_pk_table"]
table.insert_all(
({"type": "number", "number": i, "plusone": i + 1} for i in range(3)),
pk=("type", "number"),
)
pks_and_rows = list(table.pks_and_rows_where())
assert pks_and_rows == [
(("number", 0), {"type": "number", "number": 0, "plusone": 1}),
(("number", 1), {"type": "number", "number": 1, "plusone": 2}),
(("number", 2), {"type": "number", "number": 2, "plusone": 3}),
]
</document_content>
</document>
<document index="49">
<source>./tests/test_rows_from_file.py</source>
<document_content>
from sqlite_utils.utils import rows_from_file, Format, RowError
from io import BytesIO, StringIO
import pytest
@pytest.mark.parametrize(
"input,expected_format",
(
(b"id,name\n1,Cleo", Format.CSV),
(b"id\tname\n1\tCleo", Format.TSV),
(b'[{"id": "1", "name": "Cleo"}]', Format.JSON),
),
)
def test_rows_from_file_detect_format(input, expected_format):
rows, format = rows_from_file(BytesIO(input))
assert format == expected_format
rows_list = list(rows)
assert rows_list == [{"id": "1", "name": "Cleo"}]
@pytest.mark.parametrize(
"ignore_extras,extras_key,expected",
(
(True, None, [{"id": "1", "name": "Cleo"}]),
(False, "_rest", [{"id": "1", "name": "Cleo", "_rest": ["oops"]}]),
# expected of None means expect an error:
(False, False, None),
),
)
def test_rows_from_file_extra_fields_strategies(ignore_extras, extras_key, expected):
try:
rows, format = rows_from_file(
BytesIO(b"id,name\r\n1,Cleo,oops"),
format=Format.CSV,
ignore_extras=ignore_extras,
extras_key=extras_key,
)
list_rows = list(rows)
except RowError:
if expected is None:
# This is fine,
return
else:
# We did not expect an error
raise
assert list_rows == expected
def test_rows_from_file_error_on_string_io():
with pytest.raises(TypeError) as ex:
rows_from_file(StringIO("id,name\r\n1,Cleo"))
assert ex.value.args == (
"rows_from_file() requires a file-like object that supports peek(), such as io.BytesIO",
)
</document_content>
</document>
<document index="50">
<source>./tests/test_sniff.py</source>
<document_content>
from sqlite_utils import cli, Database
from click.testing import CliRunner
import pathlib
import pytest
sniff_dir = pathlib.Path(__file__).parent / "sniff"
@pytest.mark.parametrize("filepath", sniff_dir.glob("example*"))
def test_sniff(tmpdir, filepath):
db_path = str(tmpdir / "test.db")
runner = CliRunner()
result = runner.invoke(
cli.cli,
["insert", db_path, "creatures", str(filepath), "--sniff"],
catch_exceptions=False,
)
assert result.exit_code == 0, result.stdout
db = Database(db_path)
assert list(db["creatures"].rows) == [
{"id": "1", "species": "dog", "name": "Cleo", "age": "5"},
{"id": "2", "species": "dog", "name": "Pancakes", "age": "4"},
{"id": "3", "species": "cat", "name": "Mozie", "age": "8"},
{"id": "4", "species": "spider", "name": "Daisy, the tarantula", "age": "6"},
]
</document_content>
</document>
<document index="51">
<source>./tests/test_suggest_column_types.py</source>
<document_content>
import pytest
from collections import OrderedDict
from sqlite_utils.utils import suggest_column_types
@pytest.mark.parametrize(
"records,types",
[
([{"a": 1}], {"a": int}),
([{"a": 1}, {"a": None}], {"a": int}),
([{"a": "baz"}], {"a": str}),
([{"a": "baz"}, {"a": None}], {"a": str}),
([{"a": 1.2}], {"a": float}),
([{"a": 1.2}, {"a": None}], {"a": float}),
([{"a": [1]}], {"a": str}),
([{"a": [1]}, {"a": None}], {"a": str}),
([{"a": (1,)}], {"a": str}),
([{"a": {"b": 1}}], {"a": str}),
([{"a": {"b": 1}}, {"a": None}], {"a": str}),
([{"a": OrderedDict({"b": 1})}], {"a": str}),
([{"a": 1}, {"a": 1.1}], {"a": float}),
([{"a": b"b"}], {"a": bytes}),
([{"a": b"b"}, {"a": None}], {"a": bytes}),
([{"a": "a", "b": None}], {"a": str, "b": str}),
],
)
def test_suggest_column_types(records, types):
assert types == suggest_column_types(records)
</document_content>
</document>
<document index="52">
<source>./tests/test_tracer.py</source>
<document_content>
from sqlite_utils import Database
def test_tracer():
collected = []
db = Database(
memory=True, tracer=lambda sql, params: collected.append((sql, params))
)
db["dogs"].insert({"name": "Cleopaws"})
db["dogs"].enable_fts(["name"])
db["dogs"].search("Cleopaws")
assert collected == [
("PRAGMA recursive_triggers=on;", None),
("select name from sqlite_master where type = 'view'", None),
("select name from sqlite_master where type = 'table'", None),
("select name from sqlite_master where type = 'view'", None),
("select name from sqlite_master where type = 'table'", None),
("select name from sqlite_master where type = 'view'", None),
("CREATE TABLE [dogs] (\n [name] TEXT\n);\n ", None),
("select name from sqlite_master where type = 'view'", None),
("INSERT INTO [dogs] ([name]) VALUES (?)", ["Cleopaws"]),
("select name from sqlite_master where type = 'view'", None),
(
"CREATE VIRTUAL TABLE [dogs_fts] USING FTS5 (\n [name],\n content=[dogs]\n)",
None,
),
(
"INSERT INTO [dogs_fts] (rowid, [name])\n SELECT rowid, [name] FROM [dogs];",
None,
),
("select name from sqlite_master where type = 'view'", None),
]
def test_with_tracer():
collected = []
def tracer(sql, params):
return collected.append((sql, params))
db = Database(memory=True)
db["dogs"].insert({"name": "Cleopaws"})
db["dogs"].enable_fts(["name"])
assert len(collected) == 0
with db.tracer(tracer):
list(db["dogs"].search("Cleopaws"))
assert len(collected) == 5
assert collected == [
("select name from sqlite_master where type = 'view'", None),
(
(
"SELECT name FROM sqlite_master\n"
" WHERE rootpage = 0\n"
" AND (\n"
" sql LIKE :like\n"
" OR sql LIKE :like2\n"
" OR (\n"
" tbl_name = :table\n"
" AND sql LIKE '%VIRTUAL TABLE%USING FTS%'\n"
" )\n"
" )",
{
"like": "%VIRTUAL TABLE%USING FTS%content=[dogs]%",
"like2": '%VIRTUAL TABLE%USING FTS%content="dogs"%',
"table": "dogs",
},
)
),
("select name from sqlite_master where type = 'view'", None),
("select sql from sqlite_master where name = ?", ("dogs_fts",)),
(
(
"with original as (\n"
" select\n"
" rowid,\n"
" *\n"
" from [dogs]\n"
")\n"
"select\n"
" [original].*\n"
"from\n"
" [original]\n"
" join [dogs_fts] on [original].rowid = [dogs_fts].rowid\n"
"where\n"
" [dogs_fts] match :query\n"
"order by\n"
" [dogs_fts].rank"
),
{"query": "Cleopaws"},
),
]
# Outside the with block collected should not be appended to
db["dogs"].insert({"name": "Cleopaws"})
assert len(collected) == 5
</document_content>
</document>
<document index="53">
<source>./tests/test_transform.py</source>
<document_content>
from sqlite_utils.db import ForeignKey, TransformError
from sqlite_utils.utils import OperationalError
import pytest
@pytest.mark.parametrize(
"params,expected_sql",
[
# Identity transform - nothing changes
(
{},
[
"CREATE TABLE [dogs_new_suffix] (\n [id] INTEGER PRIMARY KEY,\n [name] TEXT,\n [age] TEXT\n);",
"INSERT INTO [dogs_new_suffix] ([rowid], [id], [name], [age])\n SELECT [rowid], [id], [name], [age] FROM [dogs];",
"DROP TABLE [dogs];",
"ALTER TABLE [dogs_new_suffix] RENAME TO [dogs];",
],
),
# Change column type
(
{"types": {"age": int}},
[
"CREATE TABLE [dogs_new_suffix] (\n [id] INTEGER PRIMARY KEY,\n [name] TEXT,\n [age] INTEGER\n);",
"INSERT INTO [dogs_new_suffix] ([rowid], [id], [name], [age])\n SELECT [rowid], [id], [name], [age] FROM [dogs];",
"DROP TABLE [dogs];",
"ALTER TABLE [dogs_new_suffix] RENAME TO [dogs];",
],
),
# Rename a column
(
{"rename": {"age": "dog_age"}},
[
"CREATE TABLE [dogs_new_suffix] (\n [id] INTEGER PRIMARY KEY,\n [name] TEXT,\n [dog_age] TEXT\n);",
"INSERT INTO [dogs_new_suffix] ([rowid], [id], [name], [dog_age])\n SELECT [rowid], [id], [name], [age] FROM [dogs];",
"DROP TABLE [dogs];",
"ALTER TABLE [dogs_new_suffix] RENAME TO [dogs];",
],
),
# Drop a column
(
{"drop": ["age"]},
[
"CREATE TABLE [dogs_new_suffix] (\n [id] INTEGER PRIMARY KEY,\n [name] TEXT\n);",
"INSERT INTO [dogs_new_suffix] ([rowid], [id], [name])\n SELECT [rowid], [id], [name] FROM [dogs];",
"DROP TABLE [dogs];",
"ALTER TABLE [dogs_new_suffix] RENAME TO [dogs];",
],
),
# Convert type AND rename column
(
{"types": {"age": int}, "rename": {"age": "dog_age"}},
[
"CREATE TABLE [dogs_new_suffix] (\n [id] INTEGER PRIMARY KEY,\n [name] TEXT,\n [dog_age] INTEGER\n);",
"INSERT INTO [dogs_new_suffix] ([rowid], [id], [name], [dog_age])\n SELECT [rowid], [id], [name], [age] FROM [dogs];",
"DROP TABLE [dogs];",
"ALTER TABLE [dogs_new_suffix] RENAME TO [dogs];",
],
),
# Change primary key
(
{"pk": "age"},
[
"CREATE TABLE [dogs_new_suffix] (\n [id] INTEGER,\n [name] TEXT,\n [age] TEXT PRIMARY KEY\n);",
"INSERT INTO [dogs_new_suffix] ([rowid], [id], [name], [age])\n SELECT [rowid], [id], [name], [age] FROM [dogs];",
"DROP TABLE [dogs];",
"ALTER TABLE [dogs_new_suffix] RENAME TO [dogs];",
],
),
# Change primary key to a compound pk
(
{"pk": ("age", "name")},
[
"CREATE TABLE [dogs_new_suffix] (\n [id] INTEGER,\n [name] TEXT,\n [age] TEXT,\n PRIMARY KEY ([age], [name])\n);",
"INSERT INTO [dogs_new_suffix] ([rowid], [id], [name], [age])\n SELECT [rowid], [id], [name], [age] FROM [dogs];",
"DROP TABLE [dogs];",
"ALTER TABLE [dogs_new_suffix] RENAME TO [dogs];",
],
),
# Remove primary key, creating a rowid table
(
{"pk": None},
[
"CREATE TABLE [dogs_new_suffix] (\n [id] INTEGER,\n [name] TEXT,\n [age] TEXT\n);",
"INSERT INTO [dogs_new_suffix] ([rowid], [id], [name], [age])\n SELECT [rowid], [id], [name], [age] FROM [dogs];",
"DROP TABLE [dogs];",
"ALTER TABLE [dogs_new_suffix] RENAME TO [dogs];",
],
),
# Keeping the table
(
{"drop": ["age"], "keep_table": "kept_table"},
[
"CREATE TABLE [dogs_new_suffix] (\n [id] INTEGER PRIMARY KEY,\n [name] TEXT\n);",
"INSERT INTO [dogs_new_suffix] ([rowid], [id], [name])\n SELECT [rowid], [id], [name] FROM [dogs];",
"ALTER TABLE [dogs] RENAME TO [kept_table];",
"ALTER TABLE [dogs_new_suffix] RENAME TO [dogs];",
],
),
],
)
@pytest.mark.parametrize("use_pragma_foreign_keys", [False, True])
def test_transform_sql_table_with_primary_key(
fresh_db, params, expected_sql, use_pragma_foreign_keys
):
captured = []
def tracer(sql, params):
return captured.append((sql, params))
dogs = fresh_db["dogs"]
if use_pragma_foreign_keys:
fresh_db.conn.execute("PRAGMA foreign_keys=ON")
dogs.insert({"id": 1, "name": "Cleo", "age": "5"}, pk="id")
sql = dogs.transform_sql(**{**params, **{"tmp_suffix": "suffix"}})
assert sql == expected_sql
# Check that .transform() runs without exceptions:
with fresh_db.tracer(tracer):
dogs.transform(**params)
# If use_pragma_foreign_keys, check that we did the right thing
if use_pragma_foreign_keys:
assert ("PRAGMA foreign_keys=0;", None) in captured
assert captured[-2] == ("PRAGMA foreign_key_check;", None)
assert captured[-1] == ("PRAGMA foreign_keys=1;", None)
else:
assert ("PRAGMA foreign_keys=0;", None) not in captured
assert ("PRAGMA foreign_keys=1;", None) not in captured
@pytest.mark.parametrize(
"params,expected_sql",
[
# Identity transform - nothing changes
(
{},
[
"CREATE TABLE [dogs_new_suffix] (\n [id] INTEGER,\n [name] TEXT,\n [age] TEXT\n);",
"INSERT INTO [dogs_new_suffix] ([rowid], [id], [name], [age])\n SELECT [rowid], [id], [name], [age] FROM [dogs];",
"DROP TABLE [dogs];",
"ALTER TABLE [dogs_new_suffix] RENAME TO [dogs];",
],
),
# Change column type
(
{"types": {"age": int}},
[
"CREATE TABLE [dogs_new_suffix] (\n [id] INTEGER,\n [name] TEXT,\n [age] INTEGER\n);",
"INSERT INTO [dogs_new_suffix] ([rowid], [id], [name], [age])\n SELECT [rowid], [id], [name], [age] FROM [dogs];",
"DROP TABLE [dogs];",
"ALTER TABLE [dogs_new_suffix] RENAME TO [dogs];",
],
),
# Rename a column
(
{"rename": {"age": "dog_age"}},
[
"CREATE TABLE [dogs_new_suffix] (\n [id] INTEGER,\n [name] TEXT,\n [dog_age] TEXT\n);",
"INSERT INTO [dogs_new_suffix] ([rowid], [id], [name], [dog_age])\n SELECT [rowid], [id], [name], [age] FROM [dogs];",
"DROP TABLE [dogs];",
"ALTER TABLE [dogs_new_suffix] RENAME TO [dogs];",
],
),
# Make ID a primary key
(
{"pk": "id"},
[
"CREATE TABLE [dogs_new_suffix] (\n [id] INTEGER PRIMARY KEY,\n [name] TEXT,\n [age] TEXT\n);",
"INSERT INTO [dogs_new_suffix] ([rowid], [id], [name], [age])\n SELECT [rowid], [id], [name], [age] FROM [dogs];",
"DROP TABLE [dogs];",
"ALTER TABLE [dogs_new_suffix] RENAME TO [dogs];",
],
),
],
)
@pytest.mark.parametrize("use_pragma_foreign_keys", [False, True])
def test_transform_sql_table_with_no_primary_key(
fresh_db, params, expected_sql, use_pragma_foreign_keys
):
captured = []
def tracer(sql, params):
return captured.append((sql, params))
dogs = fresh_db["dogs"]
if use_pragma_foreign_keys:
fresh_db.conn.execute("PRAGMA foreign_keys=ON")
dogs.insert({"id": 1, "name": "Cleo", "age": "5"})
sql = dogs.transform_sql(**{**params, **{"tmp_suffix": "suffix"}})
assert sql == expected_sql
# Check that .transform() runs without exceptions:
with fresh_db.tracer(tracer):
dogs.transform(**params)
# If use_pragma_foreign_keys, check that we did the right thing
if use_pragma_foreign_keys:
assert ("PRAGMA foreign_keys=0;", None) in captured
assert captured[-2] == ("PRAGMA foreign_key_check;", None)
assert captured[-1] == ("PRAGMA foreign_keys=1;", None)
else:
assert ("PRAGMA foreign_keys=0;", None) not in captured
assert ("PRAGMA foreign_keys=1;", None) not in captured
def test_transform_sql_with_no_primary_key_to_primary_key_of_id(fresh_db):
dogs = fresh_db["dogs"]
dogs.insert({"id": 1, "name": "Cleo", "age": "5"})
assert (
dogs.schema
== "CREATE TABLE [dogs] (\n [id] INTEGER,\n [name] TEXT,\n [age] TEXT\n)"
)
dogs.transform(pk="id")
# Slight oddity: [dogs] becomes "dogs" during the rename:
assert (
dogs.schema
== 'CREATE TABLE "dogs" (\n [id] INTEGER PRIMARY KEY,\n [name] TEXT,\n [age] TEXT\n)'
)
def test_transform_rename_pk(fresh_db):
dogs = fresh_db["dogs"]
dogs.insert({"id": 1, "name": "Cleo", "age": "5"}, pk="id")
dogs.transform(rename={"id": "pk"})
assert (
dogs.schema
== 'CREATE TABLE "dogs" (\n [pk] INTEGER PRIMARY KEY,\n [name] TEXT,\n [age] TEXT\n)'
)
def test_transform_not_null(fresh_db):
dogs = fresh_db["dogs"]
dogs.insert({"id": 1, "name": "Cleo", "age": "5"}, pk="id")
dogs.transform(not_null={"name"})
assert (
dogs.schema
== 'CREATE TABLE "dogs" (\n [id] INTEGER PRIMARY KEY,\n [name] TEXT NOT NULL,\n [age] TEXT\n)'
)
def test_transform_remove_a_not_null(fresh_db):
dogs = fresh_db["dogs"]
dogs.insert({"id": 1, "name": "Cleo", "age": "5"}, not_null={"age"}, pk="id")
dogs.transform(not_null={"name": True, "age": False})
assert (
dogs.schema
== 'CREATE TABLE "dogs" (\n [id] INTEGER PRIMARY KEY,\n [name] TEXT NOT NULL,\n [age] TEXT\n)'
)
@pytest.mark.parametrize("not_null", [{"age"}, {"age": True}])
def test_transform_add_not_null_with_rename(fresh_db, not_null):
dogs = fresh_db["dogs"]
dogs.insert({"id": 1, "name": "Cleo", "age": "5"}, pk="id")
dogs.transform(not_null=not_null, rename={"age": "dog_age"})
assert (
dogs.schema
== 'CREATE TABLE "dogs" (\n [id] INTEGER PRIMARY KEY,\n [name] TEXT,\n [dog_age] TEXT NOT NULL\n)'
)
def test_transform_defaults(fresh_db):
dogs = fresh_db["dogs"]
dogs.insert({"id": 1, "name": "Cleo", "age": 5}, pk="id")
dogs.transform(defaults={"age": 1})
assert (
dogs.schema
== 'CREATE TABLE "dogs" (\n [id] INTEGER PRIMARY KEY,\n [name] TEXT,\n [age] INTEGER DEFAULT 1\n)'
)
def test_transform_defaults_and_rename_column(fresh_db):
dogs = fresh_db["dogs"]
dogs.insert({"id": 1, "name": "Cleo", "age": 5}, pk="id")
dogs.transform(rename={"age": "dog_age"}, defaults={"age": 1})
assert (
dogs.schema
== 'CREATE TABLE "dogs" (\n [id] INTEGER PRIMARY KEY,\n [name] TEXT,\n [dog_age] INTEGER DEFAULT 1\n)'
)
def test_remove_defaults(fresh_db):
dogs = fresh_db["dogs"]
dogs.insert({"id": 1, "name": "Cleo", "age": 5}, defaults={"age": 1}, pk="id")
dogs.transform(defaults={"age": None})
assert (
dogs.schema
== 'CREATE TABLE "dogs" (\n [id] INTEGER PRIMARY KEY,\n [name] TEXT,\n [age] INTEGER\n)'
)
@pytest.fixture
def authors_db(fresh_db):
books = fresh_db["books"]
authors = fresh_db["authors"]
authors.insert({"id": 5, "name": "Jane McGonical"}, pk="id")
books.insert(
{"id": 2, "title": "Reality is Broken", "author_id": 5},
foreign_keys=("author_id",),
pk="id",
)
return fresh_db
def test_transform_foreign_keys_persist(authors_db):
assert authors_db["books"].foreign_keys == [
ForeignKey(
table="books", column="author_id", other_table="authors", other_column="id"
)
]
authors_db["books"].transform(rename={"title": "book_title"})
assert authors_db["books"].foreign_keys == [
ForeignKey(
table="books", column="author_id", other_table="authors", other_column="id"
)
]
@pytest.mark.parametrize("use_pragma_foreign_keys", [False, True])
def test_transform_foreign_keys_survive_renamed_column(
authors_db, use_pragma_foreign_keys
):
if use_pragma_foreign_keys:
authors_db.conn.execute("PRAGMA foreign_keys=ON")
authors_db["books"].transform(rename={"author_id": "author_id_2"})
assert authors_db["books"].foreign_keys == [
ForeignKey(
table="books",
column="author_id_2",
other_table="authors",
other_column="id",
)
]
def _add_country_city_continent(db):
db["country"].insert({"id": 1, "name": "France"}, pk="id")
db["continent"].insert({"id": 2, "name": "Europe"}, pk="id")
db["city"].insert({"id": 24, "name": "Paris"}, pk="id")
_CAVEAU = {
"id": 32,
"name": "Caveau de la Huchette",
"country": 1,
"continent": 2,
"city": 24,
}
@pytest.mark.parametrize("use_pragma_foreign_keys", [False, True])
def test_transform_drop_foreign_keys(fresh_db, use_pragma_foreign_keys):
if use_pragma_foreign_keys:
fresh_db.conn.execute("PRAGMA foreign_keys=ON")
# Create table with three foreign keys so we can drop two of them
_add_country_city_continent(fresh_db)
fresh_db["places"].insert(
_CAVEAU,
foreign_keys=("country", "continent", "city"),
)
assert fresh_db["places"].foreign_keys == [
ForeignKey(
table="places", column="city", other_table="city", other_column="id"
),
ForeignKey(
table="places",
column="continent",
other_table="continent",
other_column="id",
),
ForeignKey(
table="places", column="country", other_table="country", other_column="id"
),
]
# Drop two of those foreign keys
fresh_db["places"].transform(drop_foreign_keys=("country", "continent"))
# Should be only one foreign key now
assert fresh_db["places"].foreign_keys == [
ForeignKey(table="places", column="city", other_table="city", other_column="id")
]
if use_pragma_foreign_keys:
assert fresh_db.conn.execute("PRAGMA foreign_keys").fetchone()[0]
def test_transform_verify_foreign_keys(fresh_db):
fresh_db.conn.execute("PRAGMA foreign_keys=ON")
fresh_db["authors"].insert({"id": 3, "name": "Tina"}, pk="id")
fresh_db["books"].insert(
{"id": 1, "title": "Book", "author_id": 3}, pk="id", foreign_keys={"author_id"}
)
# Renaming the id column on authors should break everything
with pytest.raises(OperationalError) as e:
fresh_db["authors"].transform(rename={"id": "id2"})
assert e.value.args[0] == 'foreign key mismatch - "books" referencing "authors"'
# This should have rolled us back
assert (
fresh_db["authors"].schema
== "CREATE TABLE [authors] (\n [id] INTEGER PRIMARY KEY,\n [name] TEXT\n)"
)
assert fresh_db.conn.execute("PRAGMA foreign_keys").fetchone()[0]
def test_transform_add_foreign_keys_from_scratch(fresh_db):
_add_country_city_continent(fresh_db)
fresh_db["places"].insert(_CAVEAU)
# Should have no foreign keys
assert fresh_db["places"].foreign_keys == []
# Now add them using .transform()
fresh_db["places"].transform(add_foreign_keys=("country", "continent", "city"))
# Should now have all three:
assert fresh_db["places"].foreign_keys == [
ForeignKey(
table="places", column="city", other_table="city", other_column="id"
),
ForeignKey(
table="places",
column="continent",
other_table="continent",
other_column="id",
),
ForeignKey(
table="places", column="country", other_table="country", other_column="id"
),
]
assert fresh_db["places"].schema == (
'CREATE TABLE "places" (\n'
" [id] INTEGER,\n"
" [name] TEXT,\n"
" [country] INTEGER REFERENCES [country]([id]),\n"
" [continent] INTEGER REFERENCES [continent]([id]),\n"
" [city] INTEGER REFERENCES [city]([id])\n"
")"
)
@pytest.mark.parametrize(
"add_foreign_keys",
(
("country", "continent"),
# Fully specified
(
("country", "country", "id"),
("continent", "continent", "id"),
),
),
)
def test_transform_add_foreign_keys_from_partial(fresh_db, add_foreign_keys):
_add_country_city_continent(fresh_db)
fresh_db["places"].insert(
_CAVEAU,
foreign_keys=("city",),
)
# Should have one foreign key
assert fresh_db["places"].foreign_keys == [
ForeignKey(table="places", column="city", other_table="city", other_column="id")
]
# Now add three more using .transform()
fresh_db["places"].transform(add_foreign_keys=add_foreign_keys)
# Should now have all three:
assert fresh_db["places"].foreign_keys == [
ForeignKey(
table="places", column="city", other_table="city", other_column="id"
),
ForeignKey(
table="places",
column="continent",
other_table="continent",
other_column="id",
),
ForeignKey(
table="places", column="country", other_table="country", other_column="id"
),
]
@pytest.mark.parametrize(
"foreign_keys",
(
("country", "continent"),
# Fully specified
(
("country", "country", "id"),
("continent", "continent", "id"),
),
),
)
def test_transform_replace_foreign_keys(fresh_db, foreign_keys):
_add_country_city_continent(fresh_db)
fresh_db["places"].insert(
_CAVEAU,
foreign_keys=("city",),
)
assert len(fresh_db["places"].foreign_keys) == 1
# Replace with two different ones
fresh_db["places"].transform(foreign_keys=foreign_keys)
assert fresh_db["places"].schema == (
'CREATE TABLE "places" (\n'
" [id] INTEGER,\n"
" [name] TEXT,\n"
" [country] INTEGER REFERENCES [country]([id]),\n"
" [continent] INTEGER REFERENCES [continent]([id]),\n"
" [city] INTEGER\n"
")"
)
@pytest.mark.parametrize("table_type", ("id_pk", "rowid", "compound_pk"))
def test_transform_preserves_rowids(fresh_db, table_type):
pk = None
if table_type == "id_pk":
pk = "id"
elif table_type == "compound_pk":
pk = ("id", "name")
elif table_type == "rowid":
pk = None
fresh_db["places"].insert_all(
[
{"id": "1", "name": "Paris", "country": "France"},
{"id": "2", "name": "London", "country": "UK"},
{"id": "3", "name": "New York", "country": "USA"},
],
pk=pk,
)
# Now delete and insert a row to mix up the `rowid` sequence
fresh_db["places"].delete_where("id = ?", ["2"])
fresh_db["places"].insert({"id": "4", "name": "London", "country": "UK"})
previous_rows = list(
tuple(row) for row in fresh_db.execute("select rowid, id, name from places")
)
# Transform it
fresh_db["places"].transform(column_order=("country", "name"))
# Should be the same
next_rows = list(
tuple(row) for row in fresh_db.execute("select rowid, id, name from places")
)
assert previous_rows == next_rows
@pytest.mark.parametrize("strict", (False, True))
def test_transform_strict(fresh_db, strict):
dogs = fresh_db.table("dogs", strict=strict)
dogs.insert({"id": 1, "name": "Cleo"})
assert dogs.strict == strict or not fresh_db.supports_strict
dogs.transform(not_null={"name"})
assert dogs.strict == strict or not fresh_db.supports_strict
@pytest.mark.parametrize(
"indexes, transform_params",
[
([["name"]], {"types": {"age": str}}),
([["name"], ["age", "breed"]], {"types": {"age": str}}),
([], {"types": {"age": str}}),
([["name"]], {"types": {"age": str}, "keep_table": "old_dogs"}),
],
)
def test_transform_indexes(fresh_db, indexes, transform_params):
# https://github.com/simonw/sqlite-utils/issues/633
# New table should have same indexes as old table after transformation
dogs = fresh_db["dogs"]
dogs.insert({"id": 1, "name": "Cleo", "age": 5, "breed": "Labrador"}, pk="id")
for index in indexes:
dogs.create_index(index)
indexes_before_transform = dogs.indexes
dogs.transform(**transform_params)
assert sorted(
[
{k: v for k, v in idx._asdict().items() if k != "seq"}
for idx in dogs.indexes
],
key=lambda x: x["name"],
) == sorted(
[
{k: v for k, v in idx._asdict().items() if k != "seq"}
for idx in indexes_before_transform
],
key=lambda x: x["name"],
), f"Indexes before transform: {indexes_before_transform}\nIndexes after transform: {dogs.indexes}"
if "keep_table" in transform_params:
assert all(
index.origin == "pk"
for index in fresh_db[transform_params["keep_table"]].indexes
)
def test_transform_retains_indexes_with_foreign_keys(fresh_db):
dogs = fresh_db["dogs"]
owners = fresh_db["owners"]
dogs.insert({"id": 1, "name": "Cleo", "owner_id": 1}, pk="id")
owners.insert({"id": 1, "name": "Alice"}, pk="id")
dogs.create_index(["name"])
indexes_before_transform = dogs.indexes
fresh_db.add_foreign_keys([("dogs", "owner_id", "owners", "id")]) # calls transform
assert sorted(
[
{k: v for k, v in idx._asdict().items() if k != "seq"}
for idx in dogs.indexes
],
key=lambda x: x["name"],
) == sorted(
[
{k: v for k, v in idx._asdict().items() if k != "seq"}
for idx in indexes_before_transform
],
key=lambda x: x["name"],
), f"Indexes before transform: {indexes_before_transform}\nIndexes after transform: {dogs.indexes}"
@pytest.mark.parametrize(
"transform_params",
[
{"rename": {"age": "dog_age"}},
{"drop": ["age"]},
],
)
def test_transform_with_indexes_errors(fresh_db, transform_params):
# Should error with a compound (name, age) index if age is renamed or dropped
dogs = fresh_db["dogs"]
dogs.insert({"id": 1, "name": "Cleo", "age": 5}, pk="id")
dogs.create_index(["name", "age"])
with pytest.raises(TransformError) as excinfo:
dogs.transform(**transform_params)
assert (
"Index 'idx_dogs_name_age' column 'age' is not in updated table 'dogs'. "
"You must manually drop this index prior to running this transformation"
in str(excinfo.value)
)
def test_transform_with_unique_constraint_implicit_index(fresh_db):
dogs = fresh_db["dogs"]
# Create a table with a UNIQUE constraint on 'name', which creates an implicit index
fresh_db.execute(
"""
CREATE TABLE dogs (
id INTEGER PRIMARY KEY,
name TEXT UNIQUE,
age INTEGER
);
"""
)
dogs.insert({"id": 1, "name": "Cleo", "age": 5})
# Attempt to transform the table without modifying 'name'
with pytest.raises(TransformError) as excinfo:
dogs.transform(types={"age": str})
assert (
"Index 'sqlite_autoindex_dogs_1' on table 'dogs' does not have a CREATE INDEX statement."
in str(excinfo.value)
)
assert (
"You must manually drop this index prior to running this transformation and manually recreate the new index after running this transformation."
in str(excinfo.value)
)
</document_content>
</document>
<document index="54">
<source>./tests/test_update.py</source>
<document_content>
import collections
import json
import pytest
from sqlite_utils.db import NotFoundError
def test_update_rowid_table(fresh_db):
table = fresh_db["table"]
rowid = table.insert({"foo": "bar"}).last_pk
table.update(rowid, {"foo": "baz"})
assert [{"foo": "baz"}] == list(table.rows)
def test_update_pk_table(fresh_db):
table = fresh_db["table"]
pk = table.insert({"foo": "bar", "id": 5}, pk="id").last_pk
assert 5 == pk
table.update(pk, {"foo": "baz"})
assert [{"id": 5, "foo": "baz"}] == list(table.rows)
def test_update_compound_pk_table(fresh_db):
table = fresh_db["table"]
pk = table.insert({"id1": 5, "id2": 3, "v": 1}, pk=("id1", "id2")).last_pk
assert (5, 3) == pk
table.update(pk, {"v": 2})
assert [{"id1": 5, "id2": 3, "v": 2}] == list(table.rows)
@pytest.mark.parametrize(
"pk,update_pk",
(
(None, 2),
(None, None),
("id1", None),
("id1", 4),
(("id1", "id2"), None),
(("id1", "id2"), 4),
(("id1", "id2"), (4, 5)),
),
)
def test_update_invalid_pk(fresh_db, pk, update_pk):
table = fresh_db["table"]
table.insert({"id1": 5, "id2": 3, "v": 1}, pk=pk).last_pk
with pytest.raises(NotFoundError):
table.update(update_pk, {"v": 2})
def test_update_alter(fresh_db):
table = fresh_db["table"]
rowid = table.insert({"foo": "bar"}).last_pk
table.update(rowid, {"new_col": 1.2}, alter=True)
assert [{"foo": "bar", "new_col": 1.2}] == list(table.rows)
# Let's try adding three cols at once
table.update(
rowid,
{"str_col": "str", "bytes_col": b"\xa0 has bytes", "int_col": -10},
alter=True,
)
assert [
{
"foo": "bar",
"new_col": 1.2,
"str_col": "str",
"bytes_col": b"\xa0 has bytes",
"int_col": -10,
}
] == list(table.rows)
def test_update_alter_with_invalid_column_characters(fresh_db):
table = fresh_db["table"]
rowid = table.insert({"foo": "bar"}).last_pk
with pytest.raises(AssertionError):
table.update(rowid, {"new_col[abc]": 1.2}, alter=True)
def test_update_with_no_values_sets_last_pk(fresh_db):
table = fresh_db.table("dogs", pk="id")
table.insert_all([{"id": 1, "name": "Cleo"}, {"id": 2, "name": "Pancakes"}])
table.update(1)
assert table.last_pk == 1
table.update(2)
assert table.last_pk == 2
with pytest.raises(NotFoundError):
table.update(3)
@pytest.mark.parametrize(
"data_structure",
(
["list with one item"],
["list with", "two items"],
{"dictionary": "simple"},
{"dictionary": {"nested": "complex"}},
collections.OrderedDict(
[
("key1", {"nested": "complex"}),
("key2", "foo"),
]
),
[{"list": "of"}, {"two": "dicts"}],
),
)
def test_update_dictionaries_and_lists_as_json(fresh_db, data_structure):
fresh_db["test"].insert({"id": 1, "data": ""}, pk="id")
fresh_db["test"].update(1, {"data": data_structure})
row = fresh_db.execute("select id, data from test").fetchone()
assert row[0] == 1
assert data_structure == json.loads(row[1])
</document_content>
</document>
<document index="55">
<source>./tests/test_upsert.py</source>
<document_content>
from sqlite_utils.db import PrimaryKeyRequired
import pytest
def test_upsert(fresh_db):
table = fresh_db["table"]
table.insert({"id": 1, "name": "Cleo"}, pk="id")
table.upsert({"id": 1, "age": 5}, pk="id", alter=True)
assert list(table.rows) == [{"id": 1, "name": "Cleo", "age": 5}]
assert table.last_pk == 1
def test_upsert_all(fresh_db):
table = fresh_db["table"]
table.upsert_all([{"id": 1, "name": "Cleo"}, {"id": 2, "name": "Nixie"}], pk="id")
table.upsert_all([{"id": 1, "age": 5}, {"id": 2, "age": 5}], pk="id", alter=True)
assert list(table.rows) == [
{"id": 1, "name": "Cleo", "age": 5},
{"id": 2, "name": "Nixie", "age": 5},
]
assert table.last_pk is None
def test_upsert_all_single_column(fresh_db):
table = fresh_db["table"]
table.upsert_all([{"name": "Cleo"}], pk="name")
assert list(table.rows) == [{"name": "Cleo"}]
assert table.pks == ["name"]
def test_upsert_all_not_null(fresh_db):
# https://github.com/simonw/sqlite-utils/issues/538
fresh_db["comments"].upsert_all(
[{"id": 1, "name": "Cleo"}],
pk="id",
not_null=["name"],
)
assert list(fresh_db["comments"].rows) == [{"id": 1, "name": "Cleo"}]
def test_upsert_error_if_no_pk(fresh_db):
table = fresh_db["table"]
with pytest.raises(PrimaryKeyRequired):
table.upsert_all([{"id": 1, "name": "Cleo"}])
with pytest.raises(PrimaryKeyRequired):
table.upsert({"id": 1, "name": "Cleo"})
def test_upsert_with_hash_id(fresh_db):
table = fresh_db["table"]
table.upsert({"foo": "bar"}, hash_id="pk")
assert [{"pk": "a5e744d0164540d33b1d7ea616c28f2fa97e754a", "foo": "bar"}] == list(
table.rows
)
assert "a5e744d0164540d33b1d7ea616c28f2fa97e754a" == table.last_pk
@pytest.mark.parametrize("hash_id", (None, "custom_id"))
def test_upsert_with_hash_id_columns(fresh_db, hash_id):
table = fresh_db["table"]
table.upsert({"a": 1, "b": 2, "c": 3}, hash_id=hash_id, hash_id_columns=("a", "b"))
assert list(table.rows) == [
{
hash_id or "id": "4acc71e0547112eb432f0a36fb1924c4a738cb49",
"a": 1,
"b": 2,
"c": 3,
}
]
assert table.last_pk == "4acc71e0547112eb432f0a36fb1924c4a738cb49"
table.upsert({"a": 1, "b": 2, "c": 4}, hash_id=hash_id, hash_id_columns=("a", "b"))
assert list(table.rows) == [
{
hash_id or "id": "4acc71e0547112eb432f0a36fb1924c4a738cb49",
"a": 1,
"b": 2,
"c": 4,
}
]
def test_upsert_compound_primary_key(fresh_db):
table = fresh_db["table"]
table.upsert_all(
[
{"species": "dog", "id": 1, "name": "Cleo", "age": 4},
{"species": "cat", "id": 1, "name": "Catbag"},
],
pk=("species", "id"),
)
assert table.last_pk is None
table.upsert({"species": "dog", "id": 1, "age": 5}, pk=("species", "id"))
assert ("dog", 1) == table.last_pk
assert [
{"species": "dog", "id": 1, "name": "Cleo", "age": 5},
{"species": "cat", "id": 1, "name": "Catbag", "age": None},
] == list(table.rows)
# .upsert_all() with a single item should set .last_pk
table.upsert_all([{"species": "cat", "id": 1, "age": 5}], pk=("species", "id"))
assert ("cat", 1) == table.last_pk
</document_content>
</document>
<document index="56">
<source>./tests/test_utils.py</source>
<document_content>
from sqlite_utils import utils
import csv
import io
import pytest
@pytest.mark.parametrize(
"input,expected,should_be_is",
[
({}, None, True),
({"foo": "bar"}, None, True),
(
{"content": {"$base64": True, "encoded": "aGVsbG8="}},
{"content": b"hello"},
False,
),
],
)
def test_decode_base64_values(input, expected, should_be_is):
actual = utils.decode_base64_values(input)
if should_be_is:
assert actual is input
else:
assert actual == expected
@pytest.mark.parametrize(
"size,expected",
(
(1, [["a"], ["b"], ["c"], ["d"]]),
(2, [["a", "b"], ["c", "d"]]),
(3, [["a", "b", "c"], ["d"]]),
(4, [["a", "b", "c", "d"]]),
),
)
def test_chunks(size, expected):
input = ["a", "b", "c", "d"]
chunks = list(map(list, utils.chunks(input, size)))
assert chunks == expected
def test_hash_record():
expected = "d383e7c0ba88f5ffcdd09be660de164b3847401a"
assert utils.hash_record({"name": "Cleo", "twitter": "CleoPaws"}) == expected
assert (
utils.hash_record(
{"name": "Cleo", "twitter": "CleoPaws", "age": 7}, keys=("name", "twitter")
)
== expected
)
assert (
utils.hash_record({"name": "Cleo", "twitter": "CleoPaws", "age": 7}) != expected
)
def test_maximize_csv_field_size_limit():
# Reset to default in case other tests have changed it
csv.field_size_limit(utils.ORIGINAL_CSV_FIELD_SIZE_LIMIT)
long_value = "a" * 131073
long_csv = "id,text\n1,{}".format(long_value)
fp = io.BytesIO(long_csv.encode("utf-8"))
# Using rows_from_file should error
with pytest.raises(csv.Error):
rows, _ = utils.rows_from_file(fp, utils.Format.CSV)
list(rows)
# But if we call maximize_csv_field_size_limit() first it should be OK:
utils.maximize_csv_field_size_limit()
fp2 = io.BytesIO(long_csv.encode("utf-8"))
rows2, _ = utils.rows_from_file(fp2, utils.Format.CSV)
rows_list2 = list(rows2)
assert len(rows_list2) == 1
assert rows_list2[0]["id"] == "1"
assert rows_list2[0]["text"] == long_value
@pytest.mark.parametrize(
"input,expected",
(
({"foo": {"bar": 1}}, {"foo_bar": 1}),
({"foo": {"bar": [1, 2, {"baz": 3}]}}, {"foo_bar": [1, 2, {"baz": 3}]}),
({"foo": {"bar": 1, "baz": {"three": 3}}}, {"foo_bar": 1, "foo_baz_three": 3}),
),
)
def test_flatten(input, expected):
assert utils.flatten(input) == expected
</document_content>
</document>
<document index="57">
<source>./tests/test_wal.py</source>
<document_content>
import pytest
from sqlite_utils import Database
@pytest.fixture
def db_path_tmpdir(tmpdir):
path = tmpdir / "test.db"
db = Database(str(path))
return db, path, tmpdir
def test_enable_disable_wal(db_path_tmpdir):
db, path, tmpdir = db_path_tmpdir
assert len(tmpdir.listdir()) == 1
assert "delete" == db.journal_mode
assert "test.db-wal" not in [f.basename for f in tmpdir.listdir()]
db.enable_wal()
assert "wal" == db.journal_mode
db["test"].insert({"foo": "bar"})
assert "test.db-wal" in [f.basename for f in tmpdir.listdir()]
db.disable_wal()
assert "delete" == db.journal_mode
assert "test.db-wal" not in [f.basename for f in tmpdir.listdir()]
</document_content>
</document>
<document index="58">
<source>./docs/conf.py</source>
<document_content>
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from subprocess import Popen, PIPE
from beanbag_docutils.sphinx.ext.github import github_linkcode_resolve
# This file is execfile()d with the current directory set to its
# containing dir.
#
# Note that not all possible configuration values are present in this
# autogenerated file.
#
# All configuration values have a default; values that are commented out
# serve to show the default.
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
# import os
# import sys
# sys.path.insert(0, os.path.abspath('.'))
# -- General configuration ------------------------------------------------
# If your documentation needs a minimal Sphinx version, state it here.
#
# needs_sphinx = '1.0'
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
"sphinx.ext.extlinks",
"sphinx.ext.autodoc",
"sphinx_copybutton",
"sphinx.ext.linkcode",
]
autodoc_member_order = "bysource"
autodoc_typehints = "description"
extlinks = {
"issue": ("https://github.com/simonw/sqlite-utils/issues/%s", "#%s"),
}
def linkcode_resolve(domain, info):
return github_linkcode_resolve(
domain=domain,
info=info,
allowed_module_names=["sqlite_utils"],
github_org_id="simonw",
github_repo_id="sqlite-utils",
branch="main",
)
# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]
# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
# source_suffix = ['.rst', '.md']
source_suffix = ".rst"
# The master toctree document.
master_doc = "index"
# General information about the project.
project = "sqlite-utils"
copyright = "2018-2022, Simon Willison"
author = "Simon Willison"
# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
# built documents.
#
# The short X.Y version.
pipe = Popen("git describe --tags --always", stdout=PIPE, shell=True)
git_version = pipe.stdout.read().decode("utf8")
if git_version:
version = git_version.rsplit("-", 1)[0]
release = git_version
else:
version = ""
release = ""
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = "en"
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This patterns also effect to html_static_path and html_extra_path
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
# The name of the Pygments (syntax highlighting) style to use.
pygments_style = "sphinx"
# Only syntax highlight of code-block is used:
highlight_language = "none"
# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = False
# -- Options for HTML output ----------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = "furo"
html_title = "sqlite-utils"
# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
#
# html_theme_options = {}
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ["_static"]
html_js_files = ["js/custom.js"]
# -- Options for HTMLHelp output ------------------------------------------
# Output file base name for HTML help builder.
htmlhelp_basename = "sqlite-utils-doc"
# -- Options for LaTeX output ---------------------------------------------
latex_elements = {
# The paper size ('letterpaper' or 'a4paper').
#
# 'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt').
#
# 'pointsize': '10pt',
# Additional stuff for the LaTeX preamble.
#
# 'preamble': '',
# Latex figure (float) alignment
#
# 'figure_align': 'htbp',
}
# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
(
master_doc,
"sqlite-utils.tex",
"sqlite-utils documentation",
"Simon Willison",
"manual",
)
]
# -- Options for manual page output ---------------------------------------
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [(master_doc, "sqlite-utils", "sqlite-utils documentation", [author], 1)]
# -- Options for Texinfo output -------------------------------------------
# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
(
master_doc,
"sqlite-utils",
"sqlite-utils documentation",
author,
"sqlite-utils",
"Python library for manipulating SQLite databases",
"Miscellaneous",
)
]
</document_content>
</document>
</documents>
Analyze cause of test failures and suggest fixes
Okay, let's analyze the test failures from the provided output.

We see 7 failures in total:

- `tests/test_extracts.py::test_extracts[...]` (6 variants due to parametrization)
- `tests/test_lookup.py::test_lookup_with_extra_insert_parameters` (1 failure)

Both sets of failures point to issues related to the `extracts` feature, which automatically extracts unique values from a column into a separate lookup table and replaces the original column with a foreign key reference.
Failures (6 variants):
- Location:
tests/test_extracts.py:51
- Assertion:
assert [...] == fresh_db[expected_table].indexes
- Error:
AssertionError: assert [Index(seq=0,...ns=['value'])] == []
- Cause: The test expects a unique index to be created on the
value
column of the newly extracted table (e.g.,Species
orspecies_id
). The assertion shows that the test expects an index namedidx_<table_name>_value
withunique=1
andcolumns=['value']
, butfresh_db[expected_table].indexes
returns an empty list[]
.
The test_extracts
function tests the extracts
parameter when used with db.table(...).insert_all(...)
or db.table(...).insert(...)
. The extracts
feature, when used with Table.extract()
, does create this unique index on the lookup table's value
column (as seen in sqlite_utils/db.py
, Table.extract
method, around line 1100: lookup_table.create_index(lookup_columns, unique=True, if_not_exists=True)
).
However, when extracts
is passed to insert
or insert_all
, the insert_all
method's internal logic for creating the table (which calls Table.create
, which calls Database.create_table_sql
) appears to correctly configure the foreign key relationship in the schema but does not create the unique index on the lookup table's value
column.
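Paraphrasing the failing assertion: after an insert like the one sketched above, the tests expect the lookup table to expose a unique index via `.indexes`, roughly:

    # What the failing tests expect (roughly) versus what they get:
    indexes = db["Species"].indexes
    # expected: [Index(seq=0, name="idx_Species_value", unique=1, origin="c",
    #                  partial=0, columns=["value"])]
    # actual:   []   (no unique index was created on the lookup table)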
Analysis of the `test_lookup_with_extra_insert_parameters` failure:

- Location: `tests/test_lookup.py:131`
- Assertion: `assert species.get(id) == { ... }`
- Error: `AssertionError: ... Differing items: {'extract_this': 'This is extracted'} != {'extract_this': 1}`
- Cause: The test uses `Table.lookup()` with the `extracts` parameter (`extracts=["extract_this"]`). The lookup dictionary includes `"extract_this": "This is extracted"`. The test expects that after the lookup (which creates the record and performs the extraction), the `extract_this` column in the main `species` table will contain the integer primary key `1` referencing the extracted `extract_this` table. However, the assertion shows that the original string `"This is extracted"` was inserted into the `extract_this` column instead of the integer foreign key ID.

This failure confirms the suspicion from the `test_extracts` failures: the `extracts` parameter as implemented in `insert`/`upsert`/`insert_all`/`lookup` is incomplete. It sets up the foreign key column and constraint in the schema but fails to:

- Populate the extracted table with unique values.
- Replace the original values in the main table's column with the integer foreign key IDs.
- Create the unique index on the extracted table's value column (as shown by the `test_extracts` failures).

These steps are performed correctly by the `Table.extract()` method, which works by first populating the lookup table and then running an `UPDATE` query against the main table to replace the values, as sketched below. The current implementation of `extracts` in the `insert_all` flow tries to handle this during the initial insert and is missing these crucial steps.
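For comparison, a simplified sketch of what `Table.extract()` effectively does for a single column (the real method in `sqlite_utils/db.py` also handles renames, multiple columns and column type changes; the table and column names here continue the hypothetical example above):

    # 1. Create the lookup table and its unique index on "value"
    db["Species"].create({"id": int, "value": str}, pk="id", if_not_exists=True)
    db["Species"].create_index(["value"], unique=True, if_not_exists=True)

    # 2. Populate the lookup table with the distinct values from the source column
    db.execute(
        "INSERT OR IGNORE INTO [Species] ([value]) SELECT DISTINCT [species_id] FROM [trees]"
    )

    # 3. Rewrite the source column so it holds the integer foreign key
    db.execute(
        "UPDATE [trees] SET [species_id] = ("
        " SELECT [id] FROM [Species] WHERE [Species].[value] = [trees].[species_id])"
    )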
Suggested Fix:

The `extracts` parameter should be fully supported by the `insert`, `upsert`, `insert_all`, and `lookup` methods. The most robust way to implement this is to adapt the logic used in `Table.extract()` to the `Table.insert_all` method, which is the core method called by the others.

Here's a high-level approach for fixing `Table.insert_all`:

- Identify extracted columns: Before processing any data chunks, check if the `extracts` parameter is provided and parse the desired extracted columns and table names.
- Pre-process records: Iterate through the input `records` before splitting them into chunks for batch insertion.
- Perform lookup/insert for extracted values: For each record, for each column specified in `extracts`:
  - Get the original value from the record.
  - If the value is not null/empty (similar behavior to `Table.extract`):
    - Use the `Database.table(...).lookup()` method on the target extracted table (e.g. `db["Species"]`) with the original value (e.g. `{"value": "Oak"}`). This `lookup` call will automatically create the extracted table if it doesn't exist, insert the unique value if it's not already present, create the unique index on the `value` column (if creating the table or if it wasn't already unique), and return the integer primary key of the row in the extracted table.
    - Replace the original value in the current record dictionary with this integer ID obtained from the `lookup`.
- Insert modified records: Pass this new sequence of records (where extracted columns now hold integer FK IDs) to the existing chunking and insertion logic within `insert_all`. This existing logic already handles schema creation, including foreign key constraints defined by `extracts` (via `self.create`).

This approach ensures that the extracted tables are populated and indexed, and the main table receives the correct integer foreign key values.
Implementation Details (within `Table.insert_all`):

def insert_all(
    self,
    records,
    # ... other parameters ...
    extracts=DEFAULT,
    # ... rest of parameters ...
):
    # ... parameter resolution ...
    extracts_resolved = resolve_extracts(self.value_or_default("extracts", extracts))

    # New: Pre-process records if extracts are used
    if extracts_resolved:

        def preprocessed_records_generator(records_iterator):
            for record in records_iterator:
                modified_record = dict(record)
                for original_col, extracted_table_name in extracts_resolved.items():
                    original_value = modified_record.get(original_col)
                    # Decide how to handle None/empty strings - Table.extract seems to map NULL to NULL
                    if original_value is not None and original_value != "":
                        # Use lookup to get/create ID in the extracted table
                        extracted_table_obj = self.db[extracted_table_name]
                        try:
                            # Use lookup; ensure lookup creates unique index if needed
                            lookup_id = extracted_table_obj.lookup({"value": original_value})
                        except Exception as e:
                            # Handle potential errors during lookup (e.g. UNIQUE constraint
                            # violations if table exists but value isn't unique)
                            # For now, re-raise or log
                            print(f"Error during lookup for value '{original_value}' in table '{extracted_table_name}': {e}", file=sys.stderr)
                            raise
                        # Replace original value with the integer ID
                        modified_record[original_col] = lookup_id
                    # else: leave as None or original empty string
                yield modified_record

        records = preprocessed_records_generator(iter(records))
        # Note: This consumes the original iterator. Subsequent checks like num_columns
        # might need to peek again or adapt. Peeking at the first record might need
        # to be done on the preprocessed_records generator.

    # ... rest of the original insert_all logic ...
    # including the chunking loop, table creation, and insert_chunk calls.
    # Ensure that the create_table call still includes the foreign_keys parameter derived
    # from extracts_resolved. The lookup calls within the generator handle creating the
    # extracted table and its index.
    # The Table.create method already uses Database.create_table_sql which correctly
    # includes the FK definition based on the (modified) columns dictionary and foreign_keys
    # derived from extracts.
    # After the loop, potentially re-analyze if analyze=True was passed.
    # The unique index on the extracted table's 'value' column should be created
    # by the lookup() method itself if the table is created during the first lookup
    # or if the value column wasn't already unique. Test this behavior of lookup.
    # ... logic to set self.last_pk ...
    # ... return self ...
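The fix leans on the behaviour of `Table.lookup()`, which step 2 of the next steps below asks us to verify; the intended behaviour looks like this:

    # lookup() creates the table on first use, adds a unique index on the lookup
    # columns, and returns the integer primary key of the matching row
    first = db["Species"].lookup({"value": "Oak"})   # row created, returns its id
    again = db["Species"].lookup({"value": "Oak"})   # same value, same id returned
    other = db["Species"].lookup({"value": "Palm"})  # new value, new id
    assert first == again
    assert other != first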
Next Steps:

- Implement the pre-processing logic within `Table.insert_all` as outlined above.
- Ensure that `Table.lookup` guarantees the unique index on the `value` column if it creates the table or if the column is not already unique. (Review the `Table.lookup` implementation.)
- Re-run the tests, particularly `test_extracts.py` and `test_lookup.py`.
- Address any new failures or edge cases that arise from this change (e.g. handling empty strings or `None` values correctly during extraction).

This fix should resolve the `test_extracts` failures (by creating the index) and the `test_lookup` failure (by replacing the string value with the integer ID).
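For the re-run step, the relevant test files can be targeted directly, for example:

    pytest tests/test_extracts.py tests/test_lookup.py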
189,073 input, 2,206 output, {"promptTokensDetails": [{"modality": "TEXT", "tokenCount": 189073}], "thoughtsTokenCount": 7210}
diff --git a/docs/python-api.rst b/docs/python-api.rst
index c6bf776..6893e1c 100644
--- a/docs/python-api.rst
+++ b/docs/python-api.rst
@@ -927,6 +927,13 @@ An ``upsert_all()`` method is also available, which behaves like ``insert_all()``
 .. note::
     ``.upsert()`` and ``.upsert_all()`` in sqlite-utils 1.x worked like ``.insert(..., replace=True)``
     and ``.insert_all(..., replace=True)`` do in 2.x. See `issue #66 <https://github.com/simonw/sqlite-utils/issues/66>`__ for details of this change.
 
+.. _python_api_old_upsert:
+
+Alternative upserts using INSERT OR IGNORE
+------------------------------------------
+
+Upserts use ``INSERT INTO ... ON CONFLICT SET``. Prior to sqlite-utils 3.x (TODO: fill in version) these used a sequence of ``INSERT OR IGNORE`` followed by an ``UPDATE``. This older method is still used for SQLite 3.23.1 and earlier. You can force the older implementation by passing ``use_old_upsert=True`` to the ``Database()`` constructor.
+
 .. _python_api_convert:
 
 Converting data in columns
diff --git a/sqlite_utils/db.py b/sqlite_utils/db.py
index 144330a..89648e8 100644
--- a/sqlite_utils/db.py
+++ b/sqlite_utils/db.py
@@ -304,6 +304,8 @@ class Database:
sql, parameters
every time a SQL query is executed
:param use_counts_table: set to True
to use a cached counts table, if available. See
:ref:python_api_cached_table_counts
- :param use_old_upsert: set to
True
to force the older upsert implementation. See -
:param strict: Apply STRICT mode to all created tables (unless overridden) """:ref:`python_api_old_upsert`
@@ -320,10 +322,12 @@ class Database: tracer: Optional[Callable] = None, use_counts_table: bool = False, execute_plugins: bool = True,
-
): self.memory_name = None self.memory = Falseuse_old_upsert: bool = False, strict: bool = False,
-
self.use_old_upsert = use_old_upsert assert (filename_or_conn is not None and (not memory and not memory_name)) or ( filename_or_conn is None and (memory or memory_name) ), "Either specify a filename_or_conn or pass memory=True"
@@ -671,16 +675,45 @@ class Database: @property def supports_strict(self) -> bool: "Does this database support STRICT mode?"
-
try:
-
table_name = "t{}".format(secrets.token_hex(16))
-
with self.conn:
-
self.conn.execute(
-
"create table {} (name text) strict".format(table_name)
-
)
-
self.conn.execute("drop table {}".format(table_name))
-
return True
-
except Exception:
-
return False
-
if not hasattr(self, "_supports_strict"):
-
try:
-
table_name = "t{}".format(secrets.token_hex(16))
-
with self.conn:
-
self.conn.execute(
-
"create table {} (name text) strict".format(table_name)
-
)
-
self.conn.execute("drop table {}".format(table_name))
-
self._supports_strict = True
-
except Exception:
-
self._supports_strict = False
-
return self._supports_strict
-
@property
-
def supports_on_conflict(self) -> bool:
-
# SQLite's upsert is implemented as INSERT INTO ... ON CONFLICT DO ...
-
if not hasattr(self, "_supports_on_conflict"):
-
try:
-
table_name = "t{}".format(secrets.token_hex(16))
-
with self.conn:
-
self.conn.execute(
-
"create table {} (id integer primary key, name text)".format(
-
table_name
-
)
-
)
-
self.conn.execute(
-
"insert into {} (id, name) values (1, 'one')".format(table_name)
-
)
-
self.conn.execute(
-
(
-
"insert into {} (id, name) values (1, 'two') "
-
"on conflict do update set name = 'two'"
-
).format(table_name)
-
)
-
self.conn.execute("drop table {}".format(table_name))
-
self._supports_on_conflict = True
-
except Exception:
-
self._supports_on_conflict = False
-
return self._supports_on_conflict
     @property
     def sqlite_version(self) -> Tuple[int, ...]:
@@ -2966,102 +2999,125 @@ class Table(Queryable):
         replace,
         ignore,
     ):
-        # values is the list of insert data that is passed to the
-        # .execute() method - but some of them may be replaced by
-        # new primary keys if we are extracting any columns.
-        values = []
+        """
+        Given a list ``chunk`` of records that should be written to *this* table,
+        return a list of ``(sql, parameters)`` 2-tuples which, when executed in
+        order, perform the desired INSERT / UPSERT / REPLACE operation.
+        """
         if hash_id_columns and hash_id is None:
             hash_id = "id"
         extracts = resolve_extracts(extracts)
+        # Build a row-list ready for executemany-style flattening
+        values: list[list] = []
         for record in chunk:
-            record_values = []
-            for key in all_columns:
-                value = jsonify_if_needed(
-                    record.get(
-                        key,
-                        (
-                            None
-                            if key != hash_id
-                            else hash_record(record, hash_id_columns)
-                        ),
-                    )
-                )
-                if key in extracts:
-                    extract_table = extracts[key]
-                    value = self.db[extract_table].lookup({"value": value})
-                record_values.append(value)
-            values.append(record_values)
-        queries_and_params = []
-        if upsert:
-            if isinstance(pk, str):
-                pks = [pk]
-            else:
-                pks = pk
-            self.last_pk = None
-            for record_values in values:
-                record = dict(zip(all_columns, record_values))
-                placeholders = list(pks)
-                # Need to populate not-null columns too, or INSERT OR IGNORE ignores
-                # them since it ignores the resulting integrity errors
-                if not_null:
-                    placeholders.extend(not_null)
-                sql = "INSERT OR IGNORE INTO [{table}]({cols}) VALUES({placeholders});".format(
-                    table=self.name,
-                    cols=", ".join(["[{}]".format(p) for p in placeholders]),
-                    placeholders=", ".join(["?" for p in placeholders]),
-                )
-                queries_and_params.append(
-                    (sql, [record[col] for col in pks] + ["" for _ in (not_null or [])])
-                )
-                # UPDATE [book] SET [name] = 'Programming' WHERE [id] = 1001;
-                set_cols = [col for col in all_columns if col not in pks]
-                if set_cols:
-                    sql2 = "UPDATE [{table}] SET {pairs} WHERE {wheres}".format(
-                        table=self.name,
-                        pairs=", ".join(
-                            "[{}] = {}".format(col, conversions.get(col, "?"))
-                            for col in set_cols
-                        ),
-                        wheres=" AND ".join("[{}] = ?".format(pk) for pk in pks),
-                    )
-                    queries_and_params.append(
-                        (
-                            sql2,
-                            [record[col] for col in set_cols]
-                            + [record[pk] for pk in pks],
-                        )
-                    )
-                # We can populate .last_pk right here
-                if num_records_processed == 1:
-                    self.last_pk = tuple(record[pk] for pk in pks)
-                    if len(self.last_pk) == 1:
-                        self.last_pk = self.last_pk[0]
+            row_vals = []
+            for col in all_columns:
+                if col == hash_id:
+                    row_vals.append(hash_record(record, hash_id_columns))
+                    continue
-        else:
-            or_what = ""
-            if replace:
-                or_what = "OR REPLACE "
-            elif ignore:
-                or_what = "OR IGNORE "
-            sql = """
-                INSERT {or_what}INTO [{table}] ({columns}) VALUES {rows};
-            """.strip().format(
-                or_what=or_what,
-                table=self.name,
-                columns=", ".join("[{}]".format(c) for c in all_columns),
-                rows=", ".join(
-                    "({placeholders})".format(
-                        placeholders=", ".join(
-                            [conversions.get(col, "?") for col in all_columns]
+                val = record.get(col)
+                if val is None and not_null and col in not_null:
+                    val = ""
+                row_vals.append(jsonify_if_needed(val))
+            values.append(row_vals)
+        columns_sql = ", ".join(f"[{c}]" for c in all_columns)
+        placeholder_expr = ", ".join(conversions.get(c, "?") for c in all_columns)
+        row_placeholders_sql = ", ".join(f"({placeholder_expr})" for _ in values)
+        flat_params = list(itertools.chain.from_iterable(values))
+        # replace=True mean INSERT OR REPLACE INTO
+        if replace:
+            sql = (
+                f"INSERT OR REPLACE INTO [{self.name}] "
+                f"({columns_sql}) VALUES {row_placeholders_sql}"
+            )
+            return [(sql, flat_params)]
+        # If not an upsert it's an INSERT, maybe with OR IGNORE
+        if not upsert:
+            or_ignore = ""
+            if ignore:
+                or_ignore = " OR IGNORE"
+            sql = (
+                f"INSERT{or_ignore} INTO [{self.name}] "
+                f"({columns_sql}) VALUES {row_placeholders_sql}"
+            )
+            return [(sql, flat_params)]
+        # Everything from here on is for upsert=True
+        pk_cols = [pk] if isinstance(pk, str) else list(pk)
+        non_pk_cols = [c for c in all_columns if c not in pk_cols]
+        conflict_sql = ", ".join(f"[{c}]" for c in pk_cols)
+        if self.db.supports_on_conflict and not self.db.use_old_upsert:
+            if non_pk_cols:
+                # DO UPDATE
+                assignments = []
+                for c in non_pk_cols:
+                    if c in conversions:
+                        assignments.append(
+                            f"[{c}] = {conversions[c].replace('?', f'excluded.[{c}]')}"
+                        )
-                        )
-                    )
-                    for record in chunk
-                ),
+                    else:
+                        assignments.append(f"[{c}] = excluded.[{c}]")
+                do_clause = "DO UPDATE SET " + ", ".join(assignments)
+            else:
+                # All columns are in the PK – nothing to update.
+                do_clause = "DO NOTHING"
+            sql = (
+                f"INSERT INTO [{self.name}] ({columns_sql}) "
+                f"VALUES {row_placeholders_sql} "
+                f"ON CONFLICT({conflict_sql}) {do_clause}"
+            )
-            flat_values = list(itertools.chain(*values))
-            queries_and_params = [(sql, flat_values)]
+            return [(sql, flat_params)]
+        # At this point we need compatibility UPSERT for SQLite < 3.24.0
+        # (INSERT OR IGNORE + second UPDATE stage)
+        queries_and_params: list[tuple[str, list]] = []
+        insert_sql = (
+            f"INSERT OR IGNORE INTO [{self.name}] "
+            f"({columns_sql}) VALUES {row_placeholders_sql}"
+        )
+        queries_and_params.append((insert_sql, flat_params))
+        # If there is nothing to update we are done.
+        if not non_pk_cols:
+            return queries_and_params
+        # We can use UPDATE … FROM (VALUES …) on SQLite ≥ 3.33.0
+        # Older SQLite versions will run this as one UPDATE per row
+        # – which is what sqlite-utils did prior to this refactor.
+        alias_cols_sql = ", ".join(pk_cols + non_pk_cols)
+        assignments = []
+        for c in non_pk_cols:
+            if c in conversions:
+                assignments.append(f"[{c}] = {conversions[c].replace('?', f'v.[{c}]')}")
+            else:
+                assignments.append(f"[{c}] = v.[{c}]")
+        assignments_sql = ", ".join(assignments)
+        update_sql = (
+            f"UPDATE [{self.name}] AS m SET {assignments_sql} "
+            f"FROM (VALUES {row_placeholders_sql}) "
+            f"AS v({alias_cols_sql}) "
+            f"WHERE " + " AND ".join(f"m.[{c}] = v.[{c}]" for c in pk_cols)
+        )
+        # Parameters for the UPDATE – pk cols first then non-pk cols
+        update_params: list = []
+        for row in values:
+            row_dict = dict(zip(all_columns, row))
+            ordered = [row_dict[c] for c in pk_cols + non_pk_cols]
+            update_params.extend(ordered)
+        queries_and_params.append((update_sql, update_params))
         return queries_and_params

     def insert_chunk(
@@ -3079,7 +3135,7 @@ class Table(Queryable):
         num_records_processed,
         replace,
         ignore,
-    ):
+    ) -> Optional[sqlite3.Cursor]:
         queries_and_params = self.build_insert_queries_and_params(
             extracts,
             chunk,
@@ -3094,9 +3150,8 @@ class Table(Queryable):
             replace,
             ignore,
         )
+        result = None
         with self.db.conn:
-            result = None
             for query, params in queries_and_params:
                 try:
                     result = self.db.execute(query, params)
@@ -3125,7 +3180,7 @@ class Table(Queryable):
                         ignore,
                     )
-                    self.insert_chunk(
+                    result = self.insert_chunk(
                         alter,
                         extracts,
                         second_half,
@@ -3143,20 +3198,7 @@ class Table(Queryable):
                 else:
                     raise
-        if num_records_processed == 1 and not upsert:
-            self.last_rowid = result.lastrowid
-            self.last_pk = self.last_rowid
-            # self.last_rowid will be 0 if a "INSERT OR IGNORE" happened
-            if (hash_id or pk) and self.last_rowid:
-                row = list(self.rows_where("rowid = ?", [self.last_rowid]))[0]
-                if hash_id:
-                    self.last_pk = row[hash_id]
-                elif isinstance(pk, str):
-                    self.last_pk = row[pk]
-                else:
-                    self.last_pk = tuple(row[p] for p in pk)
-        return
+        return result

     def insert(
         self,
@@ -3276,6 +3318,7 @@ class Table(Queryable):
         if upsert and (not pk and not hash_id):
             raise PrimaryKeyRequired("upsert() requires a pk")
+        assert not (hash_id and pk), "Use either pk= or hash_id="
         if hash_id_columns and (hash_id is None):
             hash_id = "id"
@@ -3307,6 +3350,7 @@ class Table(Queryable):
         self.last_pk = None
         if truncate and self.exists():
             self.db.execute("DELETE FROM [{}];".format(self.name))
+        result = None
         for chunk in chunks(itertools.chain([first_record], records), batch_size):
             chunk = list(chunk)
             num_records_processed += len(chunk)
@@ -3314,6 +3358,12 @@ class Table(Queryable):
                 if not self.exists():
                     # Use the first batch to derive the table names
                     column_types = suggest_column_types(chunk)
+                    if extracts:
+                        for col in extracts:
+                            if col in column_types:
+                                column_types[col] = (
+                                    int  # This will be an integer foreign key
+                                )
                     column_types.update(columns or {})
                     self.create(
                         column_types,
@@ -3341,7 +3391,7 @@ class Table(Queryable):
                 first = False
-            self.insert_chunk(
+            result = self.insert_chunk(
                 alter,
                 extracts,
                 chunk,
@@ -3357,6 +3407,33 @@ class Table(Queryable):
                 ignore,
             )
+        # If we only handled a single row populate self.last_pk
+        if num_records_processed == 1:
+            # For an insert we need to use result.lastrowid
+            if not upsert:
+                self.last_rowid = result.lastrowid
+                if (hash_id or pk) and self.last_rowid:
+                    # Set self.last_pk to the pk(s) for that rowid
+                    row = list(self.rows_where("rowid = ?", [self.last_rowid]))[0]
+                    if hash_id:
+                        self.last_pk = row[hash_id]
+                    elif isinstance(pk, str):
+                        self.last_pk = row[pk]
+                    else:
+                        self.last_pk = tuple(row[p] for p in pk)
+                else:
+                    self.last_pk = self.last_rowid
+            else:
+                # For an upsert use first_record from earlier
+                if hash_id:
+                    self.last_pk = hash_record(first_record, hash_id_columns)
+                else:
+                    self.last_pk = (
+                        first_record[pk]
+                        if isinstance(pk, str)
+                        else tuple(first_record[p] for p in pk)
+                    )
         if analyze:
             self.analyze()
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 4033af6..88763ee 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -1117,9 +1117,8 @@ def test_upsert_alter(db_path, tmpdir):
     )
     assert result.exit_code == 1
     assert (
-        "Error: no such column: age\n\n"
-        "sql = UPDATE [dogs] SET [age] = ? WHERE [id] = ?\n"
-        "parameters = [5, 1]"
+        "Error: table dogs has no column named age\n\n"
+        "Try using --alter to add additional columns"
     ) == result.output.strip()
     result = CliRunner().invoke(
@@ -2248,7 +2247,7 @@ def test_integer_overflow_error(tmpdir):
     assert result.exit_code == 1
     assert result.output == (
         "Error: Python int too large to convert to SQLite INTEGER\n\n"
-        "sql = INSERT INTO [items] ([bignumber]) VALUES (?);\n"
+        "sql = INSERT INTO [items] ([bignumber]) VALUES (?)\n"
         "parameters = [34223049823094832094802398430298048240]\n"
     )
diff --git a/tests/test_tracer.py b/tests/test_tracer.py
index 26318ae..9dfb490 100644
--- a/tests/test_tracer.py
+++ b/tests/test_tracer.py
@@ -18,7 +18,7 @@ def test_tracer():
         ("select name from sqlite_master where type = 'view'", None),
         ("CREATE TABLE [dogs] (\n [name] TEXT\n);\n ", None),
         ("select name from sqlite_master where type = 'view'", None),
-        ("INSERT INTO [dogs] ([name]) VALUES (?);", ["Cleopaws"]),
+        ("INSERT INTO [dogs] ([name]) VALUES (?)", ["Cleopaws"]),
         ("select name from sqlite_master where type = 'view'", None),
         (
             "CREATE VIRTUAL TABLE [dogs_fts] USING FTS5 (\n [name],\n content=[dogs]\n)",
These are the changes I made which caused the tests to fail
Okay, let's analyze the failures based on the provided tracebacks and the code diff.
Understanding the Failures:
1. tests/test_extracts.py::test_extracts[...] (6 failures):
   - The core assertion failing is assert [...] == fresh_db[expected_table].indexes.
   - The assertion expects a list containing one Index object for the extracted table (e.g. Species or species_id), specifically Index(..., name="idx_<table_name>_value", unique=1, ..., columns=["value"]).
   - The test output shows the actual result is [] - an empty list.
   - Conclusion: the code under test (Table.insert_all when passed extracts) is failing to create the required unique index on the value column of the newly created extracted lookup table.
2. tests/test_lookup.py::test_lookup_with_extra_insert_parameters (1 failure):
   - The assertion failing is assert species.get(id) == { ... }.
   - The test expected the extract_this column in the main species table to hold the integer primary key of the corresponding row in the extracted extract_this table (expected {'extract_this': 1}).
   - The actual result showed that the original string value was retained ({'extract_this': 'This is extracted'}).
   - Conclusion: the code under test (Table.lookup when passed extracts) is failing to replace the original string value in the main table's column with the integer foreign key ID from the extracted lookup table.
Both sets of failures point to an incomplete implementation of the extracts feature when it is used via the insert, upsert, insert_all, or lookup methods. The extracts feature is designed to:
1. Create a separate lookup table (if it doesn't exist).
2. Populate that lookup table with unique values from the original column.
3. Add a unique index on the "value" column of the lookup table.
4. Replace the values in the original column in the main table with the integer foreign key IDs from the lookup table.
5. Add a foreign key constraint to the main table's column.
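For concreteness, here is a minimal sketch of the end state those steps should produce, using the public sqlite-utils API (the trees / Species table names are purely illustrative, not taken from the failing tests):

from sqlite_utils import Database

db = Database(memory=True)
db["trees"].insert(
    {"id": 1, "name": "Tree 1", "species_id": "Oak"},
    pk="id",
    extracts={"species_id": "Species"},
)
# The lookup table exists, holds the extracted string, and has a unique index on "value"
print(list(db["Species"].rows))   # [{'id': 1, 'value': 'Oak'}]
print(db["Species"].indexes)      # a single unique index covering ["value"]
# The main table stores the integer foreign key instead of the original string
print(db["trees"].get(1))         # {'id': 1, 'name': 'Tree 1', 'species_id': 1}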
The test failures indicate that steps 2, 3, and 4 are not happening correctly within the insert_all flow (which insert, upsert, and lookup all call).
Analyzing Your Diff:
Your diff introduces changes related to the ON CONFLICT upsert syntax (for newer SQLite versions) and improves error reporting in insert_chunk by capturing the query and params. These seem like positive changes.
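For reference, the ON CONFLICT upsert form that this targets looks roughly like the following in plain sqlite3 (the dogs table here is made up for illustration, and the syntax needs SQLite 3.24.0 or newer):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE dogs (id INTEGER PRIMARY KEY, name TEXT)")
conn.execute("INSERT INTO dogs (id, name) VALUES (1, 'Cleo')")
# Insert-or-update in a single statement, instead of INSERT OR IGNORE plus a follow-up UPDATE
conn.execute(
    "INSERT INTO dogs (id, name) VALUES (?, ?) "
    "ON CONFLICT(id) DO UPDATE SET name = excluded.name",
    (1, "Cleopaws"),
)
print(conn.execute("SELECT name FROM dogs WHERE id = 1").fetchone())  # ('Cleopaws',)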
However, your diff also modifies the build_insert_queries_and_params method in sqlite_utils/db.py. Specifically, you've removed the block that attempted to handle the extracts lookup:
-                if key in extracts:
-                    extract_table = extracts[key]
-                    value = self.db[extract_table].lookup({"value": value})
-                record_values.append(value)
-            values.append(record_values)
+            row_vals = []
+            for col in all_columns:
+                if col == hash_id:
+                    row_vals.append(hash_record(record, hash_id_columns))
+                    continue
+
+                val = record.get(col)
+                if val is None and not_null and col in not_null:
+                    val = ""
+                row_vals.append(jsonify_if_needed(val))  # <-- This now appends the original value directly
+            values.append(row_vals)
This change completely removes the code that was supposed to perform the lookup and replace the value with the ID. While the original code in that block had a subtle bug (it calculated value = ...lookup(...) but then appended the original value from the outer scope to record_values), removing it means the extraction logic is no longer attempted during insertion at all.
The Fix:
The fix requires restoring and correctly implementing the extracts logic within the insert_all flow. The best place to do this is before the records are grouped into chunks for the final SQL INSERT statement. This pre-processing step can iterate through the input records, perform the necessary lookups/inserts into the extracted tables, and modify the records to contain the integer foreign key IDs before they are passed to the batch insertion logic.
Here's how you could modify the insert_all method:
def insert_all(
self,
records,
pk=DEFAULT,
foreign_keys=DEFAULT,
column_order=DEFAULT,
not_null=DEFAULT,
defaults=DEFAULT,
batch_size=DEFAULT,
hash_id=DEFAULT,
hash_id_columns=DEFAULT,
alter=DEFAULT,
ignore=DEFAULT,
replace=DEFAULT,
truncate=False,
extracts=DEFAULT,
conversions=DEFAULT,
columns=DEFAULT,
upsert=False,
analyze=False,
strict=DEFAULT,
) -> "Table":
# ... (existing parameter resolution) ...
extracts_resolved = resolve_extracts(self.value_or_default("extracts", extracts))
conversions_resolved = self.value_or_default("conversions", conversions) or {}
if hash_id_columns and hash_id is None:
hash_id = "id"
if hash_id:
pk = hash_id # Set pk if hash_id is used
if upsert and not pk: # Check pk after hash_id resolution
raise PrimaryKeyRequired("upsert() requires a pk")
# --- Start of New/Modified Extraction Logic ---
def process_records_with_extracts(records_iterator):
for record in records_iterator:
modified_record = dict(record) # Work on a copy
for original_col, extracted_table_name in extracts_resolved.items():
original_value = modified_record.get(original_col)
# Handle None/empty strings like Table.extract does (map to None)
if original_value is not None and original_value != "":
# Use lookup to get/create ID in the extracted table
# This handles table creation, unique index, and value insertion
extracted_table_obj = self.db[extracted_table_name]
# Lookup on the 'value' column is the standard behavior for extracts
lookup_id = extracted_table_obj.lookup({"value": original_value})
# Replace original value with the integer ID
modified_record[original_col] = lookup_id
else:
# Ensure it's None if it was an empty string or None
modified_record[original_col] = None
yield modified_record
if extracts_resolved:
# Process records through the extraction generator
records = process_records_with_extracts(iter(records))
else:
# If no extracts, just use the original iterator
records = iter(records)
# --- End of New/Modified Extraction Logic ---
# (existing code continues from here, operating on the potentially modified 'records' iterator)
# Fix up any records with square braces in the column names - must happen AFTER extracts processing
records = fix_square_braces(records)
# Peek at first record to count its columns:
try:
# Peeking must happen AFTER preprocessing
records, peek_iter = itertools.tee(records, 2)
first_record = next(peek_iter)
# If we consumed the first record from peek_iter, we need to chain it back
records = itertools.chain([first_record], records)
except StopIteration:
return self # It was an empty list
num_records_processed = 0 # Reset counter for the main insertion loop
num_columns = len(first_record.keys())
# ... (rest of batch size calculation and truncate logic) ...
self.last_rowid = None
self.last_pk = None
if truncate and self.exists():
self.db.execute("DELETE FROM [{}];".format(self.name))
result = None # Initialize result for last_rowid/pk
for chunk in chunks(records, batch_size):
chunk = list(chunk) # Convert chunk iterator to list
num_records_processed += len(chunk)
# This 'first' block needs adjustment. Schema should be based on types *after* extraction lookup.
# The Table.create call below will handle this correctly IF the 'columns' dictionary
# reflects the types needed after extraction (which it does if built from the modified chunk).
# The 'all_columns' list also needs to be built from the keys of the modified records.
# Re-derive column types and all_columns from the current chunk
if not self.exists():
column_types = suggest_column_types(chunk)
# No need to special-case extracts here, the values in chunk are already IDs
# column_types for the extracted columns should now suggest 'int'
column_types.update(columns or {}) # Apply explicit column type overrides
self.create(
column_types,
pk,
foreign_keys=extracts_resolved, # Pass extracts_resolved as foreign_keys for schema creation
column_order=column_order,
not_null=not_null,
defaults=defaults,
hash_id=hash_id,
hash_id_columns=hash_id_columns,
# No need for extracts=extracts_resolved here anymore - the schema is built
# from column_types and foreign_keys now.
strict=strict,
)
# Ensure all_columns is built from keys present in the modified chunk
all_columns_set = set()
for record in chunk:
all_columns_set.update(record.keys())
all_columns = list(sorted(all_columns_set))
if hash_id and hash_id not in all_columns: # Ensure hash_id column is in all_columns if used
all_columns.insert(0, hash_id)
# Now build and execute queries for this chunk using build_insert_queries_and_params
# This method no longer needs to handle the extracts lookup itself.
queries_and_params = self.build_insert_queries_and_params(
extracts=None, # Do NOT pass extracts here, it's handled already
chunk=chunk,
all_columns=all_columns,
hash_id=hash_id,
hash_id_columns=hash_id_columns,
upsert=upsert,
pk=pk,
not_null=not_null,
conversions=conversions_resolved, # Use resolved conversions
num_records_processed=num_records_processed, # Keep this for last_pk logic? Or simplify?
replace=replace,
ignore=ignore,
)
with self.db.conn:
chunk_result = None # Track result for this chunk for last_rowid
for query, params in queries_and_params:
try:
chunk_result = self.db.execute(query, params)
except OperationalError as e:
# Handle alter and too many variables (from your diff)
if alter and (" column" in e.args[0]):
# Attempt to add any missing columns (based on modified chunk), then try again
self.add_missing_columns(chunk)
chunk_result = self.db.execute(query, params)
elif e.args[0] == "too many SQL variables":
# Recursively process smaller chunks - this needs careful parameter passing
# It's better to handle this outside the build_insert_queries_and_params loop
# Or even better, ensure build_insert_queries_and_params splits big lists
# Let's simplify and assume chunk size < SQLITE_MAX_VARS / num_cols prevents this.
# Your diff already handled this recursive call logic in insert_chunk - let's keep that structure.
# So move this try/except structure back to insert_chunk.
# For now, keep the basic error handling from your diff
raise click.ClickException(
"{}\n\nsql = {}\nparameters = {}".format(str(e), query, params)
)
else:
raise click.ClickException( # Keep this for other OperationalErrors
"{}\n\nsql = {}\nparameters = {}".format(str(e), query, params)
)
except Exception as e: # Catch any other exceptions (from your diff)
raise click.ClickException(
"{}\n\nsql = {}\nparameters = {}".format(str(e), query, params)
)
result = chunk_result # Store the result of the last query in the chunk
# (existing logic to set self.last_pk based on the final 'result')
# This logic should use the 'result' variable from the loop
# (existing analyze logic)
return self
# Move the error handling structure you added back into insert_chunk,
# as it correctly handles the "too many SQL variables" recursion there.
# The build_insert_queries_and_params method should NOT have the try/except.
# The core fixes are:
# 1. Add a pre-processing step in insert_all to iterate records and perform lookups for extracts.
# 2. Modify the records in place (or yield new ones) with the integer FK IDs.
# 3. Pass the modified records to the chunking/insertion logic.
# 4. Ensure Table.create is called with foreign_keys derived from the original extracts spec.
# 5. Remove the incorrect/incomplete extracts handling from build_insert_queries_and_params.
# 6. Keep the improved error handling structure (query/params in ClickException) from your diff, applying it to the db.execute calls.
Summary of Changes Needed:
1. Modify Table.insert_all:
   - Resolve extracts and conversions.
   - Add a generator/list comprehension that iterates through the input records, performs the lookup for each value in columns marked for extracts, and replaces the string value with the returned integer ID in a copy of the record. Handle None/empty strings -> None. (A condensed sketch of this step appears at the end of this answer.)
   - Replace the original records iterator with this new iterator/list of modified records.
   - Ensure fix_square_braces and peeking (itertools.tee) happen after the extracts preprocessing.
   - When calling self.create (if the table doesn't exist), pass foreign_keys=extracts_resolved (or merge it with any other explicit foreign_keys parameter) so the FK constraint is added to the schema.
   - Pass the resolved conversions (from the Table constructor defaults or kwargs) to build_insert_queries_and_params.
   - Pass extracts=None to build_insert_queries_and_params, as extracts are handled in the pre-processing step.
2. Modify Table.build_insert_queries_and_params:
   - Remove the block that checked if key in extracts and attempted self.db[extract_table].lookup({"value": value}). This logic is now in insert_all.
   - It should now expect that extracted columns in the chunk already contain integer IDs (or None).
3. Modify Table.insert_chunk:
   - Keep the error handling structure you added, catching OperationalError and Exception, and raising ClickException with query and params. This correctly reports the SQL that failed, including the integer overflow case.
By making these changes, the insert_all process will correctly:
1. Create and populate the extracted lookup tables.
2. Add the unique index to the lookup table's 'value' column (handled by lookup).
3. Store the correct integer foreign key IDs in the main table.
4. Add the foreign key constraint to the main table schema (handled by create).
5. Preserve the improved error reporting from your diff.
Apply these logic changes, and the tests should pass.
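For reference, the pre-processing generator described in point 1 of the summary boils down to roughly this sketch (the helper name replace_extracted_values is hypothetical, simplified from the longer listing above):

def replace_extracted_values(records, extracts_resolved, db):
    # For each record, swap extracted column values for their lookup-table ids
    for record in records:
        record = dict(record)  # work on a copy
        for col, lookup_table in extracts_resolved.items():
            value = record.get(col)
            if value is None or value == "":
                record[col] = None
            else:
                # lookup() creates the table and its unique index if needed
                # and returns the integer primary key for this value
                record[col] = db[lookup_table].lookup({"value": value})
        yield record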
196,977 input, 3,789 output, {"promptTokensDetails": [{"modality": "TEXT", "tokenCount": 196977}], "thoughtsTokenCount": 4246}
that seems WAY more complicated than I thought, what is the minimum change to fix this?
Okay, you're right, let's find the absolute minimum change to get those specific tests passing based on the errors shown and the code provided.
The tests are failing because:
1. The extracted table is not getting the required unique index on its value column (the test_extracts failures).
2. The original string value in the main table column is not being replaced with the integer foreign key ID (the test_lookup_with_extra_insert_parameters failure).
The code you removed from build_insert_queries_and_params (the if key in extracts: ... lookup(...) block) was the attempt to do this replacement, even though it was flawed in its original implementation.
The core fix needs to involve: (a) populating the extracted table; (b) getting the generated integer IDs; (c) updating the original table column with those IDs; and (d) ensuring the unique index exists on the extracted table's value column.
The simplest way to achieve (a), (b), and (d) together is to use the lookup() method for each unique value found in the original column. lookup() is designed exactly for this: it finds or creates a row based on the provided values, adds a unique index if necessary, and returns the primary key (the integer ID).
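A quick illustration of that behaviour, against an in-memory database (the Species name is illustrative):

from sqlite_utils import Database

db = Database(memory=True)
# First call creates the Species table, adds a unique index on "value",
# inserts the row and returns its integer primary key
species_id = db["Species"].lookup({"value": "Oak"})
# A second call with the same value returns the existing id rather than inserting a duplicate
assert db["Species"].lookup({"value": "Oak"}) == species_id
print(db["Species"].indexes)  # a single unique index covering ["value"]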
The simplest way to achieve (c) is to run an UPDATE query after the initial data has been inserted into the main table.
Here's the minimum change approach:
1. Keep your build_insert_queries_and_params and insert_chunk changes. They are good for the upsert logic and error reporting. The build_insert_queries_and_params method should no longer try to do the extracts lookup internally.
2. Add a post-processing step at the end of Table.insert_all. After all records have been inserted into the main table (with their original string values in the columns intended for extraction), iterate through the unique values in those columns, perform a lookup on the target extracted table for each unique value, and then run an UPDATE on the main table to replace the original strings with the new integer IDs.
This is still not trivial, but it avoids deeply embedding the extraction logic into the high-performance batch insert SQL generation.
Let's put the logic at the end of Table.insert_all:
def insert_all(
self,
records,
# ... (other parameters) ...
extracts=DEFAULT, # Keep this
# ... (rest of parameters) ...
) -> "Table":
# ... (Existing parameter resolution) ...
# Resolve extracts parameter BEFORE processing records
extracts_resolved = resolve_extracts(self.value_or_default("extracts", extracts))
# (Keep your existing logic for hash_id, pk, upsert checks)
if hash_id_columns and hash_id is None:
hash_id = "id"
if hash_id:
pk = hash_id
if upsert and not pk:
raise PrimaryKeyRequired("upsert() requires a pk")
assert not (hash_id and pk), "Use either pk= or hash_id="
# --- Start of Core Insertion Logic (mostly unchanged from your diff) ---
# This part inserts the records with original string values in extract columns
all_columns = []
first = True
num_records_processed = 0
# Fix up any records with square braces - must happen before peeking/processing
records = fix_square_braces(records)
records = iter(records) # Ensure it's an iterator
# Peek at first record to count its columns:
try:
records, peek_iter = itertools.tee(records, 2)
first_record = next(peek_iter)
records = itertools.chain([first_record], records) # Chain it back
except StopIteration:
return self # It was an empty list
num_columns = len(first_record.keys())
assert (
num_columns <= SQLITE_MAX_VARS
), "Rows can have a maximum of {} columns".format(SQLITE_MAX_VARS)
batch_size = max(1, min(batch_size, SQLITE_MAX_VARS // num_columns))
self.last_rowid = None
self.last_pk = None
if truncate and self.exists():
self.db.execute("DELETE FROM [{}];".format(self.name))
# Process chunks and insert them - this uses the original string values
# in the columns designated for extraction
result = None # Keep track of the last cursor result
for chunk in chunks(records, batch_size):
chunk = list(chunk)
num_records_processed += len(chunk)
# When creating the table (if it doesn't exist),
# ensure FKs from extracts are included in the schema definition.
# The `create` method does this IF foreign_keys param is passed.
# extracts_resolved IS the foreign_keys structure needed for schema.
if first and not self.exists():
column_types = suggest_column_types(chunk)
# Do NOT change column_types to int here - they are strings initially!
column_types.update(columns or {})
self.create(
column_types,
pk,
# Pass extracts_resolved as foreign_keys for schema creation
foreign_keys=extracts_resolved,
column_order=column_order,
not_null=not_null,
defaults=defaults,
hash_id=hash_id,
hash_id_columns=hash_id_columns,
# extracts parameter is not needed for table creation anymore
strict=strict,
)
# Ensure all_columns reflects all keys seen so far across chunks
# and includes the hash_id column if applicable.
all_columns_set = set()
for record in chunk:
all_columns_set.update(record.keys())
all_columns = list(sorted(all_columns_set))
if hash_id and hash_id not in all_columns:
all_columns.insert(0, hash_id)
# Use the insert_chunk method from your diff
# Pass extracts=None because the *lookup* logic is not happening here
chunk_result = self.insert_chunk(
alter=alter,
extracts=None, # Extracts are handled *after* this insert
chunk=chunk,
all_columns=all_columns,
hash_id=hash_id,
hash_id_columns=hash_id_columns,
upsert=upsert,
pk=pk,
not_null=not_null,
conversions=conversions, # Pass conversions if any
num_records_processed=num_records_processed, # Keep for last_pk logic?
replace=replace,
ignore=ignore,
)
# Keep the result of the last chunk inserted
if chunk_result is not None:
result = chunk_result
first = False # After the first chunk
# --- End of Core Insertion Logic ---
# --- Start of NEW Post-Processing Extraction Logic ---
if extracts_resolved:
with self.db.conn: # Ensure these happen in a single transaction
for original_col, extracted_table_name in extracts_resolved.items():
extracted_table_obj = self.db[extracted_table_name]
# 1. Populate extracted table and ensure index exists using lookup
# We need unique values from the original column in the main table
# SELECT DISTINCT value FROM main_table WHERE value IS NOT NULL AND value != ''
distinct_values = self.db.execute(
f"SELECT DISTINCT [{original_col}] FROM [{self.name}] WHERE [{original_col}] IS NOT NULL AND [{original_col}] != ''"
).fetchall()
# Use lookup for each unique value
for row in distinct_values:
original_value = row[0]
# lookup handles creating the table (if not created already by FK),
# inserting the unique value if new, and adding the unique index
extracted_table_obj.lookup({"value": original_value})
# 2. Update the original table column with integer IDs
# UPDATE main_table SET original_col = (SELECT id FROM extracted_table WHERE value = main_table.original_col)
update_sql = f"""
UPDATE [{self.name}]
SET [{original_col}] = (
SELECT id FROM [{extracted_table_name}]
WHERE [{extracted_table_name}].[value] = [{self.name}].[{original_col}]
)
WHERE [{original_col}] IS NOT NULL AND [{original_col}] != ''
"""
self.db.execute(update_sql)
# Ensure the foreign key constraint is added. This is handled by the
# `self.create()` call earlier if the table was newly created.
# If the table existed and was altered, the foreign_keys might need
# to be added via `transform`. For the minimum change, we rely on
# the `create` or `alter` logic handling FKs based on the `foreign_keys`
# parameter passed to `create`. When `extracts` is used, `foreign_keys`
# passed to `create` should include the FKs resulting from extraction.
# --- End of NEW Post-Processing Extraction Logic ---
# Set self.last_pk based on the last inserted/upserted record
if num_records_processed == 1:
if not upsert:
# last_rowid set by execute in insert_chunk
self.last_rowid = result.lastrowid if result else None
if (hash_id or pk) and self.last_rowid is not None and self.last_rowid > 0:
# Re-fetch the single row by rowid to get its actual PK(s)
row = list(self.rows_where("rowid = ?", [self.last_rowid]))
if row:
row = row[0]
if hash_id:
self.last_pk = row[hash_id]
elif isinstance(pk, str):
self.last_pk = row[pk]
else: # Compound PK
self.last_pk = tuple(row[p] for p in pk)
# Ensure last_pk is not None if rowid was > 0 and row found
if self.last_pk is None:
# This case is unlikely if rowid > 0 and row found, but defensive
self.last_pk = self.last_rowid # Fallback to rowid if PK columns were null
else:
# Should not happen if result.lastrowid > 0, but defensive
self.last_pk = None
else:
# No hash_id/pk used, or insert OR IGNORE ignored it (lastrowid = 0)
# If last_rowid is 0, it means no insert happened, last_pk should be None
self.last_pk = self.last_rowid if (self.last_rowid is not None and self.last_rowid > 0) else None
else: # Upsert
# For a single-row upsert, the PK is known from the input record
# (assuming pk or hash_id was used, which is required for upsert)
# The first_record from the peek should be available here
if hash_id:
self.last_pk = hash_record(first_record, hash_id_columns)
else: # pk was used
self.last_pk = (
first_record[pk]
if isinstance(pk, str)
else tuple(first_record[p] for p in pk)
)
else:
self.last_rowid = None
self.last_pk = None
if analyze:
self.analyze()
return self
Explanation of Minimum Changes:
1. Moved extraction logic: the code that performs the lookup and the UPDATE is moved outside the insert_chunk method and placed after the main loop that processes all chunks. The original string values are inserted first, and then a separate pass updates them.
2. Used lookup(): the post-processing iterates through the SELECT DISTINCT values from the original column. For each value it calls extracted_table_obj.lookup({"value": original_value}). This method already contains the logic to:
   - Create the extracted table if it doesn't exist.
   - Insert the unique value into the extracted table if it's not already there.
   - Crucially, ensure a unique index exists on the value column of the lookup table (fixing test_extracts).
   - Return the integer ID of the unique value in the lookup table.
3. Added an UPDATE query: a separate UPDATE statement is executed on the main table (self.name) to replace the string values in original_col with the integer IDs from the extracted table (extracted_table_name). The subquery joins the main table to the extracted table on the original value to find the correct ID.
4. Schema/FKs: the foreign key constraint in the main table schema is handled by the self.create() call (or self.add_missing_columns with alter=True), because when extracts is used with insert/create, the foreign_keys parameter for schema generation includes the definition for the extracted columns pointing at the target tables.
This approach is the minimum needed to get the tests passing: it implements the necessary data transformations (populating the lookup table, updating the main table) and ensures the correct index exists. It relies heavily on the existing lookup method's functionality.
Apply this updated insert_all logic and keep your other diffs (like the better error reporting in insert_chunk). The tests should now pass.
200,813 input, 3,231 output, {"promptTokensDetails": [{"modality": "TEXT", "tokenCount": 200813}], "thoughtsTokenCount": 1294}
How about the simplest way to put back my initial flawed implementation?
Okay, if you want the absolute minimum change to restore the state just before your diff caused those specific tests to fail, focusing only on the extracts handling in build_insert_queries_and_params that was removed:
You need to put back the attempt to call lookup inside the loop that builds the values list, right before the jsonify_if_needed call for that value.
Here's the specific snippet you need to re-insert into the build_insert_queries_and_params method in sqlite_utils/db.py:
Locate the for col in all_columns: loop.
Find the lines where you get the value val and handle the not_null case:
val = record.get(col)
if val is None and not_null and col in not_null:
val = ""
Immediately after these lines, but before row_vals.append(jsonify_if_needed(val)), insert the following if block:
# --- START Re-inserting the flawed extraction logic ---
if col in extracts:
    extract_table = extracts[col]
    # The original code attempted to call lookup here,
    # overwriting the 'val' variable *before* the final jsonify_if_needed
    val = self.db[extract_table].lookup({"value": val})
# --- END Re-inserting the flawed extraction logic ---
row_vals.append(jsonify_if_needed(val))  # This line was already there
So the structure of the loop body will look like this:
def build_insert_queries_and_params(...):
# ... other setup ...
extracts = resolve_extracts(extracts) # Keep this line or ensure extracts_resolved is available
values: list[list] = []
for record in chunk:
row_vals = []
for col in all_columns:
if col == hash_id:
row_vals.append(hash_record(record, hash_id_columns))
continue
val = record.get(col)
if val is None and not_null and col in not_null:
val = ""
# --- Re-inserted flawed extraction logic ---
if col in extracts: # Check if this column needs extraction
extract_table = extracts[col] # Get the target table name
# Attempt to lookup/create the value in the extracted table
# This calls lookup but doesn't perform the full extraction process
val = self.db[extract_table].lookup({"value": val})
# --- End re-inserted logic ---
# Append the value (original string or attempted integer ID from lookup)
# after applying necessary JSONification for non-extracted types.
row_vals.append(jsonify_if_needed(val))
values.append(row_vals)
# ... rest of the method using the 'values' list ...
Important: this is putting back the flawed implementation. The tests will likely fail in the exact same way they did before your diff (or very close to it), because:
1. The lookup() call inside this tight executemany parameter-generation loop isn't the correct place to do the full extraction process (creating the index, handling the final update step).
2. Specifically, this code does not ensure the unique index exists on the extracted table's value column, causing the test_extracts failure.
3. While lookup returns the integer ID, the interaction with jsonify_if_needed and the overall batch insert flow might not correctly result in the integer ID being stored in the main table, or the test_lookup test might rely on the index existing for the lookup call within the test itself to behave as expected.
This only serves to restore the code to the state where the failures you reported occurred. It doesn't fix the underlying issue with the extracts implementation in the insert_all flow.
204,082 input, 919 output, {"cachedContentTokenCount": 200658, "promptTokensDetails": [{"modality": "TEXT", "tokenCount": 204082}], "cacheTokensDetails": [{"modality": "TEXT", "tokenCount": 200658}], "thoughtsTokenCount": 2326}