Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

49

50

51

52

53

54

55

56

57

58

59

60

61

62

63

64

65

66

67

68

69

70

71

72

73

74

75

76

77

78

79

80

81

82

83

84

85

86

87

88

89

90

91

92

93

94

95

96

97

98

99

100

101

102

103

104

105

106

107

108

109

110

111

112

113

114

115

116

117

118

119

120

121

122

123

124

125

126

127

128

129

130

131

132

133

134

135

136

137

138

139

140

141

142

143

144

145

146

147

148

149

150

151

152

153

154

155

156

157

158

159

160

161

162

163

164

165

166

167

168

169

170

171

172

173

174

175

176

177

178

179

180

181

182

183

184

185

186

187

188

189

190

191

192

193

194

195

196

197

198

199

200

201

202

203

204

205

206

207

208

209

210

211

212

213

214

215

216

217

218

219

220

221

222

223

224

225

226

227

228

229

230

231

232

233

234

235

236

237

238

239

240

241

242

243

244

245

246

247

248

249

250

251

252

253

254

255

256

257

258

259

260

261

262

263

264

265

266

267

268

269

270

271

272

273

274

275

276

277

278

279

280

281

282

283

284

285

286

287

288

289

290

291

292

293

294

295

296

297

298

299

300

301

302

303

304

305

306

307

308

309

310

311

312

313

314

315

316

317

318

319

320

321

322

323

324

325

326

327

328

329

330

331

332

333

334

335

336

337

338

339

340

341

342

343

344

345

346

347

348

349

350

351

352

353

354

355

356

357

358

359

360

361

362

363

364

365

366

367

368

369

370

371

372

373

374

375

376

377

378

379

380

381

382

383

384

385

386

387

388

389

390

391

392

393

394

395

396

397

398

399

400

401

402

403

404

405

406

407

408

409

410

411

412

413

414

415

416

417

418

419

420

421

422

423

424

425

426

427

428

429

430

431

432

433

434

435

436

437

438

439

440

441

442

443

444

445

446

447

448

449

450

451

452

453

454

455

456

457

458

459

460

461

462

463

464

465

466

467

468

469

470

471

472

473

474

475

476

477

478

479

480

481

482

483

484

485

486

487

488

489

490

491

492

493

494

495

496

497

498

499

500

501

502

503

504

505

506

507

508

509

510

511

512

513

514

515

516

517

518

519

520

521

522

523

524

525

526

527

528

529

530

531

532

533

534

535

536

537

538

539

540

541

542

543

544

545

546

547

548

549

550

551

552

553

554

555

556

557

558

559

560

561

562

563

564

565

566

# 

# BSD licence 

# 

# Author : Pierre Quentel (pierre.quentel@gmail.com) 

# 

 

 

""" 

Main differences from :mod:`pydblite.pydblite`: 

 

- pass the connection to the :class:`SQLite db <pydblite.sqlite.Database>` as argument to 

  :class:`Table <pydblite.sqlite.Table>` 

- in :func:`create() <pydblite.sqlite.Table.create>` field definitions must specify a type. 

- no `drop_field` (not supported by SQLite) 

- the :class:`Table <pydblite.sqlite.Table>` instance has a 

  :attr:`cursor <pydblite.sqlite.Database.cursor>` attribute, so that raw SQL requests can 

  be executed. 

""" 

 

try: 

    import cStringIO as io 

 

    def to_str(val, encoding="utf-8"):  # encode a Unicode string to a Python 2 str 

        return val.encode(encoding) 

except ImportError: 

    import io 

    unicode = str  # used in tests 

 

    def to_str(val):  # leaves a Unicode unchanged 

        return val 

 

import datetime 

import re 

import traceback 

 

from .common import ExpressionGroup, Filter 

 

# test if sqlite is installed or raise exception 

try: 

    from sqlite3 import dbapi2 as sqlite 

    from sqlite3 import OperationalError 

except ImportError: 

    try: 

        from pysqlite2 import dbapi2 as sqlite 

        from pysqlite2._sqlite import OperationalError 

    except ImportError: 

        print("SQLite is not installed") 

        raise 

 

# compatibility with Python 2.3 

try: 

    set([]) 

except NameError: 

    from sets import Set as set  # NOQA 

 

 

# classes for CURRENT_DATE, CURRENT_TIME, CURRENT_TIMESTAMP 

class CurrentDate: 

    def __call__(self): 

        return datetime.date.today().strftime('%Y-%M-%D') 

 

 

class CurrentTime: 

    def __call__(self): 

        return datetime.datetime.now().strftime('%h:%m:%s') 

 

 

class CurrentTimestamp: 

    def __call__(self): 

        return datetime.datetime.now().strftime('%Y-%M-%D %h:%m:%s') 

 

DEFAULT_CLASSES = [CurrentDate, CurrentTime, CurrentTimestamp] 

 

# functions to convert a value returned by a SQLite SELECT 

 

# CURRENT_TIME format is HH:MM:SS 

# CURRENT_DATE : YYYY-MM-DD 

# CURRENT_TIMESTAMP : YYYY-MM-DD HH:MM:SS 

 

c_time_fmt = re.compile('^(\d{2}):(\d{2}):(\d{2})$') 

c_date_fmt = re.compile('^(\d{4})-(\d{2})-(\d{2})$') 

c_tmsp_fmt = re.compile('^(\d{4})-(\d{2})-(\d{2}) (\d{2}):(\d{2}):(\d{2})') 

 

 

# DATE : convert YYYY-MM-DD to datetime.date instance 

def to_date(date): 

88    if date is None: 

        return None 

    mo = c_date_fmt.match(date) 

91    if not mo: 

        raise ValueError("Bad value %s for DATE format" % date) 

    year, month, day = [int(x) for x in mo.groups()] 

    return datetime.date(year, month, day) 

 

 

# TIME : convert HH-MM-SS to datetime.time instance 

def to_time(_time): 

    if _time is None: 

        return None 

    mo = c_time_fmt.match(_time) 

    if not mo: 

        raise ValueError("Bad value %s for TIME format" % _time) 

    hour, minute, second = [int(x) for x in mo.groups()] 

    return datetime.time(hour, minute, second) 

 

 

# DATETIME or TIMESTAMP : convert %YYYY-MM-DD HH:MM:SS 

# to datetime.datetime instance 

def to_datetime(timestamp): 

    if timestamp is None: 

        return None 

    if not isinstance(timestamp, unicode): 

        raise ValueError("Bad value %s for TIMESTAMP format" % timestamp) 

    mo = c_tmsp_fmt.match(timestamp) 

    if not mo: 

        raise ValueError("Bad value %s for TIMESTAMP format" % timestamp) 

    return datetime.datetime(*[int(x) for x in mo.groups()]) 

 

 

# if default value is CURRENT_DATE etc. SQLite doesn't 

# give the information, default is the value of the 

# variable as a string. We have to guess... 

# 

def guess_default_fmt(value): 

    mo = c_time_fmt.match(value) 

127    if mo: 

        h, m, s = [int(x) for x in mo.groups()] 

        if (0 <= h <= 23) and (0 <= m <= 59) and (0 <= s <= 59): 

            return CurrentTime 

    mo = c_date_fmt.match(value) 

132    if mo: 

        y, m, d = [int(x) for x in mo.groups()] 

        try: 

            datetime.date(y, m, d) 

            return CurrentDate 

        except: 

            pass 

    mo = c_tmsp_fmt.match(value) 

140    if mo: 

        y, mth, d, h, mn, s = [int(x) for x in mo.groups()] 

        try: 

            datetime.datetime(y, mth, d, h, mn, s) 

            return CurrentTimestamp 

        except: 

            pass 

    return value 

 

 

class SQLiteError(Exception): 

    """SQLiteError""" 

    pass 

 

 

class Database(dict): 

 

    def __init__(self, filename, **kw): 

        """ 

        To create an in-memory database provide ':memory:' as filename 

 

        Args: 

            - filename (str): The name of the database file, or ':memory:' 

            - kw (dict): Arguments forwarded to sqlite3.connect 

        """ 

        dict.__init__(self) 

        self.conn = sqlite.connect(filename, **kw) 

        """The SQLite connection""" 

        self.cursor = self.conn.cursor() 

        """The SQLite connections cursor""" 

        for table_name in self._tables(): 

            self[table_name] = Table(table_name, self) 

 

    def _tables(self): 

        """Return the list of table names in the database""" 

        tables = [] 

        self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table'") 

        for table_info in self.cursor.fetchall(): 

176            if table_info[0] != 'sqlite_sequence': 

                tables.append(table_info[0]) 

        return tables 

 

    def create(self, table_name, *fields, **kw): 

        self[table_name] = Table(table_name, self).create(*fields, **kw) 

        return self[table_name] 

 

    def commit(self): 

        """Save any changes to the database""" 

        self.conn.commit() 

 

    def __delitem__(self, table): 

        # drop table 

        if isinstance(table, Table): 

            table = table.name 

        self.cursor.execute('DROP TABLE %s' % table) 

        dict.__delitem__(self, table) 

 

 

class Table(object): 

 

    def __init__(self, table_name, db): 

        """ 

        Args: 

 

           - table_name (str): The name of the SQLite table. 

           - db (:class:`Database <pydblite.sqlite.Database>`): The database. 

 

        """ 

        self.name = table_name 

        self.db = db 

        self.cursor = db.cursor 

        """The SQLite connections cursor""" 

        self.conv_func = {} 

        self.mode = "open" 

        self._get_table_info() 

 

    def create(self, *fields, **kw): 

        """ 

        Create a new table. 

 

        Args: 

           - fields (list of tuples): The fields names/types to create. 

             For each field, a 2-element tuple must be provided: 

 

             - the field name 

             - a string with additional information like field type + 

               other information using the SQLite syntax 

               eg  ('name', 'TEXT NOT NULL'), ('date', 'BLOB DEFAULT CURRENT_DATE') 

 

           - mode (str): The mode used when creating the database. 

                  mode is only used if a database file already exists. 

 

             - if mode = 'open' : open the existing base, ignore the fields 

             - if mode = 'override' : erase the existing base and create a 

               new one with the specified fields 

 

        Returns: 

            - the database (self). 

        """ 

        self.mode = mode = kw.get("mode", None) 

 

        if self._table_exists(): 

            if mode == "override": 

                self.cursor.execute("DROP TABLE %s" % self.name) 

            elif mode == "open": 

                return self.open() 

            else: 

                raise IOError("Base '%s' already exists" % self.name) 

 

        sql = "CREATE TABLE %s (" % self.name 

        for field in fields: 

            sql += self._validate_field(field) + ',' 

        sql = sql[:-1] + ')' 

        self.cursor.execute(sql) 

        self._get_table_info() 

        return self 

 

    def open(self): 

        """Open an existing database.""" 

        return self 

 

    def commit(self): 

        """Save any changes to the database""" 

        self.db.commit() 

 

    def _table_exists(self): 

        return self.name in self.db 

 

    def _get_table_info(self): 

        """Inspect the base to get field names.""" 

        self.fields = [] 

        self.field_info = {} 

        self.cursor.execute('PRAGMA table_info (%s)' % self.name) 

        for field_info in self.cursor.fetchall(): 

            fname = to_str(field_info[1]) 

            self.fields.append(fname) 

            ftype = to_str(field_info[2]) 

            info = {'type': ftype} 

            # can be null ? 

            info['NOT NULL'] = field_info[3] != 0 

            # default value 

            default = field_info[4] 

            if isinstance(default, unicode): 

                default = guess_default_fmt(default) 

            info['DEFAULT'] = default 

            self.field_info[fname] = info 

        self.fields_with_id = ['__id__'] + self.fields 

 

    def info(self): 

        # returns information about the table 

        return [(field, self.field_info[field]) for field in self.fields] 

 

    def _validate_field(self, field): 

293        if len(field) != 2 and len(field) != 3: 

            msg = "Error in field definition %s" % field 

            msg += ": should be a tuple with field_name, field_info, and optionally a default value" 

            raise SQLiteError(msg) 

        field_sql = '%s %s' % (field[0], field[1]) 

        if len(field) == 3 and field[2] is not None: 

            field_sql += " DEFAULT {0}".format(field[2]) 

        return field_sql 

 

    def conv(self, field_name, conv_func): 

        """When a record is returned by a SELECT, ask conversion of 

        specified field value with the specified function.""" 

305        if field_name not in self.fields: 

            raise NameError("Unknown field %s" % field_name) 

        self.conv_func[field_name] = conv_func 

 

    def is_date(self, field_name): 

        """Ask conversion of field to an instance of datetime.date""" 

        self.conv(field_name, to_date) 

 

    def is_time(self, field_name): 

        """Ask conversion of field to an instance of datetime.date""" 

        self.conv(field_name, to_time) 

 

    def is_datetime(self, field_name): 

        """Ask conversion of field to an instance of datetime.date""" 

        self.conv(field_name, to_datetime) 

 

    def insert(self, *args, **kw): 

        """Insert a record in the database. 

 

        Parameters can be positional or keyword arguments. If positional 

        they must be in the same order as in the :func:`create` method. 

 

        Returns: 

            - The record identifier 

        """ 

        if args: 

            if isinstance(args[0], (list, tuple)): 

                return self._insert_many(args[0]) 

            kw = dict([(f, arg) for f, arg in zip(self.fields, args)]) 

 

        ks = kw.keys() 

        s1 = ",".join(ks) 

        qm = ','.join(['?'] * len(ks)) 

        sql = "INSERT INTO %s (%s) VALUES (%s)" % (self.name, s1, qm) 

        self.cursor.execute(sql, list(kw.values())) 

        return self.cursor.lastrowid 

 

    def _insert_many(self, args): 

        """Insert a list or tuple of records 

 

        Returns: 

            - The last row id 

        """ 

        sql = "INSERT INTO %s" % self.name 

        sql += "(%s) VALUES (%s)" 

        if isinstance(args[0], dict): 

            ks = args[0].keys() 

            sql = sql % (', '.join(ks), ','.join(['?' for k in ks])) 

            args = [[arg[k] for k in ks] for arg in args] 

        else: 

            sql = sql % (', '.join(self.fields), 

                         ','.join(['?' for f in self.fields])) 

        try: 

            self.cursor.executemany(sql, args) 

        except: 

            raise Exception(self._err_msg(sql, args)) 

        # return last row id 

        return self.cursor.lastrowid 

 

    def delete(self, removed): 

        """Remove a single record, or the records in an iterable. 

 

        Before starting deletion, test if all records are in the base 

        and don't have twice the same __id__. 

 

        Returns: 

             - int: the number of deleted items 

        """ 

        sql = "DELETE FROM %s " % self.name 

381        if isinstance(removed, dict): 

            # remove a single record 

            _id = removed['__id__'] 

            sql += "WHERE rowid = ?" 

            args = (_id,) 

            removed = [removed] 

        else: 

            # convert iterable into a list 

            removed = [r for r in removed] 

            if not removed: 

                return 0 

            args = [r['__id__'] for r in removed] 

            sql += "WHERE rowid IN (%s)" % (','.join(['?'] * len(args))) 

        self.cursor.execute(sql, args) 

        self.db.commit() 

        return len(removed) 

 

    def update(self, record, **kw): 

        """Update the record with new keys and values.""" 

        vals = self._make_sql_params(kw) 

        sql = "UPDATE %s SET %s WHERE rowid=?" % (self.name, 

                                                  ",".join(vals)) 

        self.cursor.execute(sql, list(kw.values()) + [record['__id__']]) 

        self.db.commit() 

 

    def _make_sql_params(self, kw): 

        """Make a list of strings to pass to an SQL statement 

        from the dictionary kw with Python types.""" 

        return ['%s=?' % k for k in kw.keys()] 

 

    def _make_record(self, row, fields=None): 

        """Make a record dictionary from the result of a fetch_""" 

407        if fields is None: 

            fields = self.fields_with_id 

        res = dict(zip(fields, row)) 

        for field_name in self.conv_func: 

            res[field_name] = self.conv_func[field_name](res[field_name]) 

        return res 

 

    def add_field(self, name, column_type="TEXT", default=None): 

        """Add a new column to the table. 

 

        Args: 

           - name (string): The name of the field 

           - column_type (string): The data type of the column (Defaults to TEXT) 

           - default (datatype): The default value for this field (if any) 

 

        """ 

        sql = "ALTER TABLE %s ADD " % self.name 

        sql += self._validate_field((name, column_type, default)) 

        self.cursor.execute(sql) 

        self.db.commit() 

        self._get_table_info() 

 

    def drop_field(self, field): 

        raise SQLiteError("Dropping fields is not supported by SQLite") 

 

    def __call__(self, *args, **kw): 

        """ 

        Selection by field values. 

 

        db(key=value) returns the list of records where r[key] = value 

 

        Args: 

           - args (list): A field to filter on. 

           - kw (dict): pairs of field and value to filter on. 

 

        Returns: 

           - When args supplied, return a :class:`Filter <pydblite.common.Filter>` 

             object that filters on the specified field. 

           - When kw supplied, return all the records where field values matches 

             the key/values in kw. 

 

        """ 

448        if args and kw: 

            raise SyntaxError("Can't specify positional AND keyword arguments") 

 

        use_expression = False 

        if args: 

453            if len(args) > 1: 

                raise SyntaxError("Only one field can be specified") 

455            if type(args[0]) is ExpressionGroup or type(args[0]) is Filter: 

                use_expression = True 

457            elif args[0] not in self.fields: 

                raise ValueError("%s is not a field" % args[0]) 

            else: 

                return self.filter(key=args[0]) 

 

462        if use_expression: 

            sql = "SELECT rowid,* FROM %s WHERE %s" % (self.name, args[0]) 

            self.cursor.execute(sql) 

            return [self._make_record(row) for row in self.cursor.fetchall()] 

        else: 

            if kw: 

                undef = set(kw) - set(self.fields) 

469                if undef: 

                    raise ValueError("Fields %s not in the database" % undef) 

                vals = self._make_sql_params(kw) 

                sql = "SELECT rowid,* FROM %s WHERE %s" % (self.name, " AND ".join(vals)) 

                self.cursor.execute(sql, list(kw.values())) 

            else: 

                self.cursor.execute("SELECT rowid,* FROM %s" % self.name) 

            records = self.cursor.fetchall() 

            return [self._make_record(row) for row in records] 

 

    def __getitem__(self, record_id): 

        """Direct access by record id.""" 

        sql = "SELECT rowid,* FROM %s WHERE rowid=%s" % (self.name, record_id) 

        self.cursor.execute(sql) 

        res = self.cursor.fetchone() 

        if res is None: 

            raise IndexError("No record at index %s" % record_id) 

        else: 

            return self._make_record(res) 

 

    def filter(self, key=None): 

        return Filter(self, key) 

 

    def _len(self, db_filter=None): 

        if db_filter and str(db_filter): 

            sql = "SELECT COUNT(*) AS count FROM %s WHERE %s" % (self.name, db_filter) 

        else: 

            sql = "SELECT COUNT(*) AS count FROM %s;" % self.name 

        self.cursor.execute(sql) 

        res = self.cursor.fetchone() 

        return res[0] 

 

    def __len__(self): 

        return self._len() 

 

    def __delitem__(self, record_id): 

        """Delete by record id""" 

        self.delete(self[record_id]) 

 

    def __iter__(self): 

        """Iteration on the records""" 

        self.cursor.execute("SELECT rowid,* FROM %s" % self.name) 

        results = [self._make_record(r) for r in self.cursor.fetchall()] 

        return iter(results) 

 

    def _err_msg(self, sql, args=None): 

        msg = "Exception for table %s.%s\n" % (self.db, self.name) 

        msg += 'SQL request %s\n' % sql 

        if args: 

            import pprint 

            msg += 'Arguments : %s\n' % pprint.saferepr(args) 

        out = io.StringIO() 

        traceback.print_exc(file=out) 

        msg += out.getvalue() 

        return msg 

 

    def get_group_count(self, group_by, db_filter=None): 

526        if db_filter and str(db_filter): 

            sql = "SELECT %s, COUNT(*) FROM %s GROUP BY %s WHERE %s" % (group_by, self.name, 

                                                                        group_by, db_filter) 

        else: 

            sql = "SELECT %s, COUNT(*) FROM %s GROUP BY %s;" % (group_by, self.name, group_by) 

        self.cursor.execute(sql) 

        return self.cursor.fetchall() 

 

    def get_unique_ids(self, unique_id, db_filter=None): 

        sql = "SELECT rowid,%s FROM %s" % (unique_id, self.name) 

536        if db_filter and str(db_filter): 

            sql += " WHERE %s" % db_filter 

        self.cursor.execute(sql) 

        records = self.cursor.fetchall() 

        return set([row[1] for row in records]) 

 

    def create_index(self, *index_columns): 

        for ic in index_columns: 

            sql = "CREATE INDEX index_%s on %s (%s);" % (ic, self.name, ic) 

            self.cursor.execute(sql) 

        self.db.commit() 

 

    def delete_index(self, *index_columns): 

        for ic in index_columns: 

            sql = "DROP INDEX index_%s;" % (ic) 

            self.cursor.execute(sql) 

        self.db.commit() 

 

    def get_indices(self): 

        indices = [] 

        sql = "SELECT * FROM sqlite_master WHERE type = 'index';" 

        try: 

            self.cursor.execute(sql) 

        except OperationalError: 

            return indices 

 

        records = self.cursor.fetchall() 

        for r in records: 

            indices.append(r[1][len("index_"):]) 

        return indices 

 

Base = Table  # compatibility with previous versions