Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

49

50

51

52

53

54

55

56

57

58

59

60

61

62

63

64

65

66

67

68

69

70

71

72

73

74

75

76

77

78

79

80

81

82

83

84

85

86

87

88

89

90

91

92

93

94

95

96

97

98

99

100

101

102

103

104

105

106

107

108

109

110

111

112

113

114

115

116

117

118

119

120

121

122

123

124

125

126

127

128

129

130

131

132

133

134

135

136

137

138

139

140

141

142

143

144

145

146

147

148

149

150

151

152

153

154

155

156

157

158

159

160

161

162

163

164

165

166

167

168

169

170

171

172

173

174

175

176

177

178

179

180

181

182

183

184

185

186

187

188

189

190

191

192

193

194

195

196

197

198

199

200

201

202

203

204

205

206

207

208

209

210

211

212

213

214

215

216

217

218

219

220

221

222

223

224

225

226

227

228

229

230

231

232

233

234

235

236

237

238

239

240

241

242

243

244

245

246

247

248

249

250

251

252

253

254

255

256

257

258

259

260

261

262

263

264

265

266

267

268

269

270

271

272

273

274

275

276

277

278

279

280

281

282

283

284

285

286

287

288

289

290

291

292

293

294

295

296

297

298

299

300

301

302

303

304

305

306

307

308

309

310

311

312

313

314

315

316

317

318

319

320

321

322

323

324

325

326

327

328

329

330

331

332

333

334

335

336

337

338

339

340

341

342

343

344

345

346

347

348

349

350

351

352

353

354

355

356

357

358

359

360

361

362

363

364

365

366

367

368

369

370

371

372

373

374

375

376

377

378

379

380

381

382

383

384

385

386

387

388

389

390

391

392

393

394

395

396

397

398

399

400

401

402

403

404

405

406

407

408

409

410

411

412

413

414

415

416

417

418

419

420

421

422

423

424

425

426

427

428

429

430

431

432

433

434

435

436

437

438

439

440

441

442

443

444

445

446

447

448

449

450

451

452

453

454

455

456

457

458

459

460

461

462

463

464

465

466

467

468

469

470

471

472

473

474

475

476

477

478

479

480

481

482

483

484

485

486

487

488

489

490

491

492

493

494

495

496

497

498

499

500

501

502

503

504

505

506

507

508

509

510

511

512

513

514

515

516

517

518

519

520

521

522

523

524

525

526

527

528

529

530

531

532

533

534

535

536

537

538

539

540

# 

# BSD licence 

# 

# Author : Pierre Quentel (pierre.quentel@gmail.com) 

# 

 

import bisect 

import operator 

import os 

import sys 

from itertools import groupby 

 

from .common import Expression, ExpressionGroup, Filter 

 

try: 

    import cPickle as pickle 

except: 

    import pickle 

 

# Version string of this module.
version = "3.0.1"

 

 

def _in(a, b):
    """Return True when ``a`` is a member of ``b`` (the ``IN`` operation)."""
    return a in b

 

 

def like(a, b):
    """Case-insensitive substring test: is ``b`` contained in ``a``?"""
    haystack = a.lower()
    needle = b.lower()
    return needle in haystack

 

 

class PyDbExpression(Expression):
    """Single comparison (key, operator, value) evaluated on record dicts."""

    def __init__(self, **kwargs):
        """Forward ``kwargs`` to the parent class and build the operator table."""
        super(PyDbExpression, self).__init__(**kwargs)
        # Maps an operator token to the callable implementing it.
        # 'AND' / 'OR' map to themselves (plain strings, not callables).
        table = {'AND': 'AND', 'OR': 'OR'}
        table['LIKE'] = like
        table['GLOB'] = operator.contains
        table["IN"] = _in
        table['='] = operator.eq
        table['!='] = operator.ne
        table['<'] = operator.lt
        table['<='] = operator.le
        table['>'] = operator.gt
        table['>='] = operator.ge
        self.operations = table

    def apply(self, records):
        """Return the list of records for which the expression holds.

        Each record is indexed with ``self.key`` and compared against
        ``self.value`` using the operation selected by ``self.operator``.
        """
        compare = self.operations[self.operator]
        matching = []
        for record in records:
            if compare(record[self.key], self.value):
                matching.append(record)
        return matching

 

 

class PyDbExpressionGroup(ExpressionGroup):
    """Expression tree node: either one expression or two combined sub-groups."""

    def apply_filter(self, records):
        """Evaluate this group against ``records``.

        Args:
            - records: mapping whose ``values()`` are the record dicts.

        Returns:
            - an empty string when ``self.is_dummy()`` is true
            - the result of ``self.expression.apply`` for a single expression
            - otherwise the records selected by both sub-groups (AND) or by
              either of them (any other operator); records are matched by
              object identity.
        """
        if self.is_dummy():
            # NOTE(review): an empty string (not an empty list) is returned
            # here; kept as-is because callers may rely on it.
            return ""
        if self.expression:
            return self.expression.apply(records.values())

        # Parent of two expressions
        records1 = self.exp_group1.apply_filter(records)
        records2 = self.exp_group2.apply_filter(records)
        if self.exp_operator == Filter.operations.AND:
            # Intersection: keep the records present in both results
            ids1 = dict([(id(r), r) for r in records1])
            ids2 = dict([(id(r), r) for r in records2])
            ids = set(ids1.keys()) & set(ids2.keys())
            records = [ids1[_id] for _id in ids]
        else:
            # Union: identity-keyed dict removes duplicates
            ids = dict([(id(r), r) for r in records1])
            ids.update(dict([(id(r), r) for r in records2]))
            records = ids.values()
        return records

 

 

class PyDbFilter(Filter):
    """Filter bound to a database and a key, backed by a PyDbExpressionGroup."""

    def __init__(self, db, key):
        """Store the database and key and start with a fresh expression group."""
        self.db, self.key = db, key
        self.expression_group = PyDbExpressionGroup()
        # Class used to build the individual expressions of this filter
        self.expression_t = PyDbExpression

    def apply_filter(self, records):
        """Delegate the evaluation to the underlying expression group."""
        group = self.expression_group
        return group.apply_filter(records)

 

 

class Index(object):
    """Read access to the index a database keeps on one field.

    Instances are stored as attributes of the Base instance."""

    def __init__(self, db, field):
        # database object (instance of Base)
        self.db = db
        # name of the indexed field
        self.field = field

    def _entries(self):
        """Return the value -> ids mapping stored in the database for the field."""
        return self.db.indices[self.field]

    def __iter__(self):
        """Iterate over the indexed values."""
        return iter(self._entries())

    def keys(self):
        """Return the indexed values."""
        return self._entries().keys()

    def __getitem__(self, key):
        """Lookup by key : return the list of records whose field value
        equals ``key``, or an empty list when the value is not indexed."""
        found = []
        for _id in self._entries().get(key, []):
            found.append(self.db.records[_id])
        return found

 

 

class _Base(object):
    """Record store kept in memory and optionally persisted with pickle.

    Records are dicts stored in ``self.records`` under an integer id; each
    record also carries the internal keys ``__id__`` and ``__version__``.
    ``self.indices`` maps an indexed field name to a mapping
    value -> sorted list of record ids.
    """

    def __init__(self, path, protocol=pickle.HIGHEST_PROTOCOL, save_to_file=True,
                 sqlite_compat=False):
        """Prepare a database bound to ``path``.

        Args:
            - path (str): location of the database file; the special value
              ``":memory:"`` disables saving to a file.
            - protocol (int): pickle protocol. Defaults to the highest
              protocol available. For maximum compatibility use protocol = 0
            - save_to_file (bool): when False, commit() writes nothing.
            - sqlite_compat (bool): when True, insert() accepts a list/tuple
              of records as its first positional argument.
        """
        # The path of the database in the file system
        self.path = path
        # The basename of the path, stripped of its extension
        self.name = os.path.splitext(os.path.basename(path))[0]
        self.protocol = protocol
        self.mode = None
        if path == ":memory:":
            save_to_file = False
        self.save_to_file = save_to_file
        self.sqlite_compat = sqlite_compat
        # The list of the fields (does not include the internal
        # fields __id__ and __version__)
        self.fields = []
        # if base exists, get field names
        if save_to_file and self.exists():
            # SECURITY: pickle must only be used on trusted database files
            with self._open_for_reading() as stream:
                self.fields = pickle.load(stream)

    def _open_for_reading(self):
        """Open the database file for unpickling and return the file object."""
        # Python 2 reads protocol 0 files in text mode (don't specify binary
        # mode !); Python 3 pickle needs a binary stream for every protocol.
        if self.protocol == 0 and sys.version_info[0] == 2:
            return open(self.path)
        return open(self.path, 'rb')

    def __iter__(self):
        """Iteration on the records"""
        return iter(self.records.values())

    def exists(self):
        """
        Returns:
            - bool: if the database file exists
        """
        return os.path.isfile(self.path)

    def create(self, *fields, **kw):
        """
        Create a new base with specified field names.

        Args:
            - \\*fields: The fields to create. Each one is either a field
              name, a ``(name, default)`` tuple or a dict with the key
              ``"name"`` and optionally ``"default"``.
            - mode (str): the mode used when the database file already exists.

        - if mode = 'open' : open the existing base, ignore the fields
        - if mode = 'override' : erase the existing base and create a
          new one with the specified fields

        Returns:
            - Returns the database (self).

        Raises:
            - IOError: the path exists and is not a file, or the base already
              exists and no mode was given.
            - ValueError: the base already exists and mode is not recognised.
        """
        self.mode = mode = kw.get("mode", None)
        if self.save_to_file and os.path.exists(self.path):
            if not os.path.isfile(self.path):
                raise IOError("%s exists and is not a file" % self.path)
            elif mode is None:
                raise IOError("Base %s already exists" % self.path)
            elif mode == "open":
                return self.open()
            elif mode == "override":
                os.remove(self.path)
            else:
                raise ValueError("Invalid value given for 'mode': '%s'" % mode)

        if mode is None:
            # insert() refuses to work while self.mode is falsy, so a base
            # created without an explicit mode must still be marked as set up
            self.mode = "create"

        self.fields = []
        self.default_values = {}
        for field in fields:
            if isinstance(field, dict):
                name, default = field["name"], field.get("default", None)
            elif isinstance(field, tuple):
                name, default = field[0], field[1]
            else:
                name, default = field, None
            self.fields.append(name)
            self.default_values[name] = default

        self.records = {}
        self.next_id = 0
        self.indices = {}
        self.commit()
        return self

    def create_index(self, *fields):
        """
        Create an index on the specified field names

        An index on a field is a mapping between the values taken by the field
        and the sorted list of the ids of the records whose field is equal to
        this value

        For each indexed field, an attribute of self is created, an instance
        of the class Index. Its name is the field name, with the
        prefix _ to avoid name conflicts

        Args:
            - fields (list): the fields to index

        Raises:
            - NameError: one of the names is not a field of the base.
        """
        reset = False
        for f in fields:
            if f not in self.fields:
                raise NameError("%s is not a field name %s" % (f, self.fields))
            # an opened base already holds its stored indices
            if self.mode == "open" and f in self.indices:
                continue
            reset = True
            self.indices[f] = {}
            for _id, record in self.records.items():
                # use bisect to quickly insert the id in the list
                bisect.insort(self.indices[f].setdefault(record[f], []), _id)
            # create a new attribute of self, used to find the records
            # by this index
            setattr(self, '_' + f, Index(self, f))
        if reset:
            self.commit()

    def delete_index(self, *fields):
        """Delete the index on the specified fields.

        Raises:
            - ValueError: one of the fields has no index (nothing is deleted).
        """
        for f in fields:
            if f not in self.indices:
                raise ValueError("No index on field %s" % f)
        for f in fields:
            del self.indices[f]
        self.commit()

    def open(self):
        """Open an existing database and load its content into memory.

        Returns:
            - Returns the database (self).
        """
        # SECURITY: pickle must only be used on trusted database files
        with self._open_for_reading() as stream:
            self.fields = pickle.load(stream)
            self.next_id = pickle.load(stream)
            self.records = pickle.load(stream)
            self.indices = pickle.load(stream)
            try:
                self.default_values = pickle.load(stream)
            except EOFError:
                # Old databases do not store default values: fall back to
                # None for every field, as create() does for plain names
                self.default_values = dict((f, None) for f in self.fields)
        for f in self.indices:
            setattr(self, '_' + f, Index(self, f))
        self.mode = "open"
        return self

    def commit(self):
        """Write the database to a file (no-op when save_to_file is False)."""
        if self.save_to_file is False:
            return
        content = (self.fields, self.next_id, self.records, self.indices,
                   self.default_values)
        with open(self.path, 'wb') as out:
            for part in content:
                pickle.dump(part, out, self.protocol)

    def insert(self, *args, **kw):
        """
        Insert one or more records in the database.

        Parameters can be positional or keyword arguments. If positional
        they must be in the same order as in the create() method
        If some of the fields are missing the value is set to the default
        value of the field. When positional arguments are given, keyword
        arguments are ignored.

        Args:
            - args (values, or a list/tuple of values): The record(s) to insert.
              A list/tuple of records is only accepted with sqlite_compat.
            - kw (dict): The field/values to insert

        Returns:
            - Returns the record identifier if inserting one item, else None.

        Raises:
            - RuntimeError: the base was neither created nor opened.
            - NameError: a keyword is not a field of the base.
        """
        if not self.mode:
            raise RuntimeError("Database columns have not been setup!")
        if args:
            if self.sqlite_compat and isinstance(args[0], (list, tuple)):
                for e in args[0]:
                    if isinstance(e, dict):
                        self.insert(**e)
                    else:
                        self.insert(*e)
                return None
            kw = dict(zip(self.fields, args))
        # raise exception if unknown field
        for key in kw:
            if key not in self.fields:
                raise NameError("Invalid field name : %s" % key)
        # initialize all fields to (a private copy of) the default values
        import copy
        record = copy.deepcopy(self.default_values)
        # set keys and values
        record.update(kw)
        # add the key __id__ : record identifier
        record['__id__'] = self.next_id
        # add the key __version__ : version number
        record['__version__'] = 0
        # create an entry in the dictionary self.records, indexed by __id__
        self.records[self.next_id] = record
        # update index
        for ix in self.indices:
            bisect.insort(self.indices[ix].setdefault(record[ix], []), self.next_id)
        # increment the next __id__
        self.next_id += 1
        return record['__id__']

    def delete(self, remove):
        """
        Remove a single record, or the records in an iterable

        Before starting deletion, test if all records are in the base
        and don't have twice the same __id__

        Args:
            - remove (record or list of records): The record(s) to delete.

        Returns:
            - Return the number of deleted items

        Raises:
            - IndexError: a record is not in the base, or an id is duplicated.
        """
        if isinstance(remove, dict):
            remove = [remove]
        else:
            # convert iterable into a list (it is traversed several times)
            remove = list(remove)
        if not remove:
            return 0
        _ids = sorted(r['__id__'] for r in remove)
        # check if the records are in the base
        missing = set(_ids).difference(self.records)
        if missing:
            raise IndexError('Delete aborted. Records with these ids'
                             ' not found in the base : %s' % str(list(missing)))
        # raise exception if duplicate ids (equal ids are adjacent once sorted)
        for previous, current in zip(_ids, _ids[1:]):
            if previous == current:
                raise IndexError("Delete aborted. Duplicate id : %s" % previous)
        deleted = len(remove)
        while remove:
            r = remove.pop()
            _id = r['__id__']
            # remove id from indices
            for indx in self.indices:
                bucket = self.indices[indx][r[indx]]
                pos = bisect.bisect(bucket, _id) - 1
                del bucket[pos]
                if not bucket:
                    del self.indices[indx][r[indx]]
            # remove record from self.records
            del self.records[_id]
        return deleted

    def update(self, records, **kw):
        """
        Update one record or a list of records
        with new keys and values and update indices

        Unknown field names in ``kw`` are ignored.

        Args:
           - records (record or iterable of records): The record(s) to update.
        """
        # ignore unknown fields
        kw = dict((k, v) for (k, v) in kw.items() if k in self.fields)
        if isinstance(records, dict):
            records = [records]
        else:
            # the records are traversed several times below; a generator
            # would be exhausted after the first pass
            records = list(records)
        # update indices
        for indx in set(self.indices) & set(kw):
            for record in records:
                if record[indx] == kw[indx]:
                    continue
                _id = record["__id__"]
                # remove id for the old value
                bucket = self.indices[indx][record[indx]]
                old_pos = bisect.bisect(bucket, _id) - 1
                del bucket[old_pos]
                if not bucket:
                    del self.indices[indx][record[indx]]
                # insert new value
                bisect.insort(self.indices[indx].setdefault(kw[indx], []), _id)
        for record in records:
            # update record values
            record.update(kw)
            # increment version number
            record["__version__"] += 1

    def add_field(self, field, column_type="ignored", default=None):
        """Adds a field to the database.

        Every existing record gets ``default`` as value for the new field.
        ``column_type`` is not used.

        Raises:
            - ValueError: the field already exists or is an internal field.
        """
        if field in self.fields + ["__id__", "__version__"]:
            raise ValueError("Field %s already defined" % field)
        if not hasattr(self, 'records'):  # base not open yet
            self.open()
        for r in self:
            r[field] = default
        self.fields.append(field)
        self.default_values[field] = default
        self.commit()

    def drop_field(self, field):
        """Removes a field from the database.

        Raises:
            - ValueError: the field is internal or is not a field of the base.
        """
        if field in ["__id__", "__version__"]:
            raise ValueError("Can't delete field %s" % field)
        self.fields.remove(field)
        # otherwise insert() would re-create the field from its default value
        self.default_values.pop(field, None)
        for r in self:
            del r[field]
        if field in self.indices:
            del self.indices[field]
        self.commit()

    def __call__(self, *args, **kw):
        """Selection by field values

        db(key=value) returns the list of records where r[key] = value

        Args:
            - args (list): A field to filter on, or a filter / expression
              group to apply.
            - kw (dict): pairs of field and value to filter on.

        Returns:
            - When a field name is supplied, return a Filter object that
              filters on the specified field.
            - When a filter or expression group is supplied, return the
              result of applying it to the records.
            - When kw supplied, return all the records where field values
              matches the key/values in kw.
            - Without arguments, return all the records.

        Raises:
            - SyntaxError: positional and keyword arguments are mixed, or
              more than one positional argument is given.
            - ValueError: the positional argument is not a field.
        """
        if args and kw:
            raise SyntaxError("Can't specify positional AND keyword arguments")

        if args:
            if len(args) > 1:
                raise SyntaxError("Only one field can be specified")
            elif isinstance(args[0], (PyDbExpressionGroup, PyDbFilter)):
                return args[0].apply_filter(self.records)
            elif args[0] not in self.fields:
                raise ValueError("%s is not a field" % args[0])
            else:
                return PyDbFilter(self, args[0])
        if not kw:
            return self.records.values()  # db() returns all the values

        # indices and non-indices
        ixs = set(kw) & set(self.indices)
        no_ix = set(kw) - ixs
        if ixs:
            # fast selection on indices
            ix = ixs.pop()
            res = set(self.indices[ix].get(kw[ix], []))
            if not res:
                return []
            while ixs:
                ix = ixs.pop()
                res = res & set(self.indices[ix].get(kw[ix], []))
        else:
            # if no index, initialize result with test on first field
            field = no_ix.pop()
            res = set(r["__id__"] for r in self if r[field] == kw[field])
        # selection on non-index fields
        for field in no_ix:
            res = set(_id for _id in res if self.records[_id][field] == kw[field])
        return [self[_id] for _id in res]

    def __getitem__(self, key):
        """Direct access by record id."""
        return self.records[key]

    def _len(self, db_filter=None):
        """Return the number of records, restricted by ``db_filter`` when it
        is given and reports itself as filtered.

        Raises:
            - ValueError: db_filter is not a PyDbExpressionGroup.
        """
        if db_filter is not None:
            if not isinstance(db_filter, PyDbExpressionGroup):
                raise ValueError("Filter argument is not of type "
                                 "'PyDbExpressionGroup': %s" % type(db_filter))
            if db_filter.is_filtered():
                return len(db_filter.apply_filter(self.records))
        return len(self.records)

    def __len__(self):
        return self._len()

    def __delitem__(self, record_id):
        """Delete by record id"""
        self.delete(self[record_id])

    def __contains__(self, record_id):
        return record_id in self.records

    def group_by(self, column, torrents_filter):
        """Count the records of ``torrents_filter`` per value of ``column``.

        Returns:
            - list of (value, count) tuples.
        """
        result = {}
        for row in torrents_filter:
            value = row[column]
            result[value] = result.get(value, 0) + 1
        return [(value, result[value]) for value in result]

    def filter(self, key=None):
        """Return a new PyDbFilter on this base for ``key``."""
        return PyDbFilter(self, key)

    def get_group_count(self, group_by_field, db_filter=None):
        """Count the records produced by iterating ``db_filter`` per value
        of ``group_by_field``; without a filter, ``self.filter()`` is used.

        Returns:
            - list of (value, count) tuples.
        """
        if db_filter is None:
            db_filter = self.filter()

        groups_dict = {}
        for row in db_filter:
            group = row[group_by_field]
            groups_dict[group] = groups_dict.get(group, 0) + 1
        return [(k, groups_dict[k]) for k in groups_dict]

    def get_unique_ids(self, id_value, db_filter=None):
        """Returns a set of unique values from column ``id_value``.

        The filter is applied only when it is given and reports itself as
        filtered; otherwise all the records are used.
        """
        if db_filter is not None and db_filter.is_filtered():
            records = self(db_filter)
        else:
            records = self()
        return set(row[id_value] for row in records)

    def get_indices(self):
        """Returns the names of the indexed fields as a list."""
        return list(self.indices)

 

 

class _BasePy2(_Base):
    """Variant of the base class that iterates with ``dict.itervalues``."""

    def __iter__(self):
        """Iteration on the records"""
        record_iterator = self.records.itervalues()
        return iter(record_iterator)

 

 

class _BasePy3(_Base):
    """Variant of the base class that iterates with ``dict.values``."""

    def __iter__(self):
        """Iteration on the records"""
        stored = self.records.values()
        return iter(stored)

 

# Expose as ``Base`` the class variant matching the major version of the
# running interpreter.
if sys.version_info[0] == 2:
    Base = _BasePy2
else:
    Base = _BasePy3