Commit 56761cd4 authored by nkindlon
Browse files

Changes for multi-DB intersection.

parent ccd125c1
......@@ -37,7 +37,7 @@ FileIntersect::~FileIntersect(void) {
delete _recordOutputMgr;
}
void FileIntersect::processHits(RecordKeyList &hits) {
// Forward the hit set gathered for one query record to the output
// manager, which is responsible for printing it.
void FileIntersect::processHits(RecordKeyVector &hits) {
_recordOutputMgr->printRecord(hits);
}
......@@ -58,11 +58,11 @@ bool FileIntersect::processSortedFiles()
return false;
}
RecordKeyList hitSet;
RecordKeyVector hitSet;
while (sweep.next(hitSet)) {
if (_context->getObeySplits()) {
RecordKeyList keySet(hitSet.getKey());
RecordKeyList resultSet(hitSet.getKey());
RecordKeyVector keySet(hitSet.getKey());
RecordKeyVector resultSet(hitSet.getKey());
_blockMgr->findBlockedOverlaps(keySet, hitSet, resultSet);
processHits(resultSet);
} else {
......@@ -85,11 +85,11 @@ bool FileIntersect::processUnsortedFiles()
if (queryRecord == NULL) {
continue;
}
RecordKeyList hitSet(queryRecord);
RecordKeyVector hitSet(queryRecord);
binTree->getHits(queryRecord, hitSet);
if (_context->getObeySplits()) {
RecordKeyList keySet(hitSet.getKey());
RecordKeyList resultSet;
RecordKeyVector keySet(hitSet.getKey());
RecordKeyVector resultSet;
_blockMgr->findBlockedOverlaps(keySet, hitSet, resultSet);
processHits(resultSet);
} else {
......
......@@ -14,7 +14,7 @@
using namespace std;
#include "RecordKeyList.h"
#include "RecordKeyVector.h"
using namespace std;
......@@ -37,7 +37,7 @@ private:
BlockMgr *_blockMgr;
RecordOutputMgr *_recordOutputMgr;
void processHits(RecordKeyList &hits);
void processHits(RecordKeyVector &hits);
bool processSortedFiles();
bool processUnsortedFiles();
......
......@@ -13,6 +13,7 @@ using namespace std;
#include "intersectFile.h"
#include "ContextIntersect.h"
#include "CommonHelp.h"
// define our program name
#define PROGRAM_NAME "bedtools intersect"
......@@ -46,6 +47,9 @@ void intersect_help(void) {
cerr << "Usage: " << PROGRAM_NAME << " [OPTIONS] -a <bed/gff/vcf> -b <bed/gff/vcf>" << endl << endl;
cerr << "\t\t" << "Note: -b may be followed with multiple databases and/or " << endl;
cerr << "\t\t" "wildcard (*) character(s). " << endl;
cerr << "Options: " << endl;
cerr << "\t-abam\t" << "The A input file is in BAM format. Output will be BAM as well." << endl << endl;
......@@ -119,6 +123,17 @@ void intersect_help(void) {
cerr <<"\t\tother software tools and scripts that need to process one" << endl;
cerr <<"\t\tline of bedtools output at a time." << endl << endl;
cerr << "\t-names\t" << "When using multiple databases, provide an alias for each that" << endl;
cerr <<"\t\twill appear instead of a fileId when also printing the DB record." << endl << endl;
cerr << "\t-filenames" << "\tWhen using multiple databases, show each complete filename" << endl;
cerr <<"\t\t\tinstead of a fileId when also printing the DB record." << endl << endl;
cerr << "\t-sortout\t" << "When using multiple databases, sort the output DB hits" << endl;
cerr << "\t\t\tfor each record." << endl << endl;
CommonHelp();
cerr << "Notes: " << endl;
cerr << "\t(1) When a BAM file is used for the A file, the alignment is retained if overlaps exist," << endl;
cerr << "\tand exlcuded if an overlap cannot be found. If multiple overlaps exist, they are not" << endl;
......
......@@ -49,11 +49,11 @@ bool Jaccard::getIntersectionAndUnion() {
if (!sweep.init()) {
return false;
}
RecordKeyList hitSet;
RecordKeyVector hitSet;
while (sweep.next(hitSet)) {
if (_context->getObeySplits()) {
RecordKeyList keySet(hitSet.getKey());
RecordKeyList resultSet(hitSet.getKey());
RecordKeyVector keySet(hitSet.getKey());
RecordKeyVector resultSet(hitSet.getKey());
_blockMgr->findBlockedOverlaps(keySet, hitSet, resultSet);
_intersectionVal += getTotalIntersection(&resultSet);
} else {
......@@ -69,7 +69,7 @@ bool Jaccard::getIntersectionAndUnion() {
return true;
}
unsigned long Jaccard::getTotalIntersection(RecordKeyList *recList)
unsigned long Jaccard::getTotalIntersection(RecordKeyVector *recList)
{
unsigned long intersection = 0;
const Record *key = recList->getKey();
......@@ -77,8 +77,8 @@ unsigned long Jaccard::getTotalIntersection(RecordKeyList *recList)
int keyEnd = key->getEndPos();
int hitIdx = 0;
for (RecordKeyList::const_iterator_type iter = recList->begin(); iter != recList->end(); iter = recList->next()) {
const Record *currRec = iter->value();
for (RecordKeyVector::const_iterator_type iter = recList->begin(); iter != recList->end(); iter = recList->next()) {
const Record *currRec = *iter;
int maxStart = max(currRec->getStartPos(), keyStart);
int minEnd = min(currRec->getEndPos(), keyEnd);
if (_context->getObeySplits()) {
......
......@@ -33,7 +33,7 @@ private:
int _numIntersections;
bool getIntersectionAndUnion();
unsigned long getTotalIntersection(RecordKeyList *hits);
unsigned long getTotalIntersection(RecordKeyVector *hits);
};
#endif /* JACCARD_H */
......@@ -41,11 +41,11 @@ bool FileMap::mapFiles()
if (!sweep.init()) {
return false;
}
RecordKeyList hitSet;
RecordKeyVector hitSet;
while (sweep.next(hitSet)) {
if (_context->getObeySplits()) {
RecordKeyList keySet(hitSet.getKey());
RecordKeyList resultSet(hitSet.getKey());
RecordKeyVector keySet(hitSet.getKey());
RecordKeyVector resultSet(hitSet.getKey());
_blockMgr->findBlockedOverlaps(keySet, hitSet, resultSet);
_recordOutputMgr->printRecord(resultSet.getKey(), _context->getColumnOpsVal(resultSet));
} else {
......
......@@ -17,7 +17,7 @@ using namespace std;
#include <sstream>
#include <iomanip>
#include "VectorOps.h"
#include "RecordKeyList.h"
#include "RecordKeyVector.h"
#include "ContextMap.h"
using namespace std;
......
......@@ -28,7 +28,7 @@ MergeFile::~MergeFile()
bool MergeFile::merge()
{
RecordKeyList hitSet;
RecordKeyVector hitSet;
FileRecordMgr *frm = _context->getFile(0);
while (!frm->eof()) {
Record *key = frm->getNextRecord(&hitSet);
......
......@@ -3,8 +3,7 @@
BinTree::BinTree(ContextIntersect *context)
: _databaseFile(NULL),
_context(context),
: _context(context),
_binOffsetsExtended(NULL),
_showBinMetrics(false),
_maxBinNumFound(0)
......@@ -36,7 +35,7 @@ BinTree::~BinTree() {
}
for (innerListIterType listIter = bin->begin(); listIter != bin->end(); listIter = bin->next()) {
const Record *record = listIter->value();
_databaseFile->deleteRecord(record);
_context->getFile(record->getFileIdx())->deleteRecord(record);
}
delete bin;
bin = NULL;
......@@ -70,25 +69,27 @@ BinTree::~BinTree() {
void BinTree::loadDB()
{
_databaseFile = _context->getFile(_context->getDatabaseFileIdx());
Record *record = NULL;
while (!_databaseFile->eof()) {
record = _databaseFile->getNextRecord();
//In addition to NULL records, we also don't want to add unmapped reads.
if (record == NULL || record->isUnmapped()) {
continue;
}
for (int i=0; i < _context->getNumDatabaseFiles(); i++) {
FileRecordMgr *databaseFile = _context->getDatabaseFile(i);
Record *record = NULL;
while (!databaseFile->eof()) {
record = databaseFile->getNextRecord();
//In addition to NULL records, we also don't want to add unmapped reads.
if (record == NULL || record->isUnmapped()) {
continue;
}
if (!addRecordToTree(record)) {
fprintf(stderr, "ERROR: Unable to add record to tree.\n");
_databaseFile->close();
exit(1);
if (!addRecordToTree(record)) {
fprintf(stderr, "ERROR: Unable to add record to tree.\n");
databaseFile->close();
exit(1);
}
}
}
}
void BinTree::getHits(Record *record, RecordKeyList &hitSet)
void BinTree::getHits(Record *record, RecordKeyVector &hitSet)
{
if (_showBinMetrics) {
return; //don't care about query entries just yet.
......@@ -149,6 +150,9 @@ void BinTree::getHits(Record *record, RecordKeyList &hitSet)
startBin >>= _binNextShift;
endBin >>= _binNextShift;
}
if (_context->getSortOutput()) {
hitSet.sortVector();
}
}
bool BinTree::addRecordToTree(const Record *record)
......
......@@ -28,11 +28,10 @@ public:
~BinTree();
void loadDB();
void getHits(Record *record, RecordKeyList &hitSet);
void getHits(Record *record, RecordKeyVector &hitSet);
private:
FileRecordMgr *_databaseFile;
ContextIntersect *_context;
//
......@@ -52,8 +51,8 @@ private:
static const uint32_t _binFirstShift = 14; /* How much to shift to get to finest bin. */
static const uint32_t _binNextShift = 3; /* How much to shift to get to next larger bin. */
typedef BTlist<const Record *> innerListType;
typedef const BTlistNode<const Record *> * innerListIterType;
typedef RecordList innerListType;
typedef const RecordListNode * innerListIterType;
typedef innerListType * binType;
typedef binType * allBinsType;
typedef QuickString mainKeyType;
......
......@@ -36,18 +36,17 @@ ContextBase::ContextBase()
_reciprocal(false),
_sameStrand(false),
_diffStrand(false),
_sortedInput(false),
_sortedInput(false),
_sortOutput(false),
_reportDBnameTags(false),
_reportDBfileNames(false),
_printHeader(false),
_printable(true),
_explicitBedOutput(false),
_queryFileIdx(-1),
_databaseFileIdx(-1),
_bamHeaderAndRefIdx(-1),
_maxNumDatabaseFields(0),
_useFullBamTags(false),
_reportCount(false),
_reportNames(false),
_reportScores(false),
_numOutputRecords(0),
_hasConstantSeed(false),
_seed(0),
......@@ -193,6 +192,9 @@ bool ContextBase::parseCmdArgs(int argc, char **argv, int skipFirstArgs) {
else if (strcmp(_argv[_i], "-delim") == 0) {
if (!handle_delim()) return false;
}
else if (strcmp(_argv[_i], "-sortout") == 0) {
if (!handle_sortout()) return false;
}
}
return true;
......@@ -210,7 +212,11 @@ bool ContextBase::isValidState()
return false;
}
if (hasColumnOpsMethods()) {
FileRecordMgr *dbFile = getFile(hasIntersectMethods() ? _databaseFileIdx : 0);
//TBD: Adjust column ops for multiple databases.
//For now, use last file.
// FileRecordMgr *dbFile = getFile(hasIntersectMethods() ? _databaseFileIdx : 0);
FileRecordMgr *dbFile = getFile(getNumInputFiles()-1);
_keyListOps->setDBfileType(dbFile->getFileType());
if (!_keyListOps->isValidColumnOps(dbFile)) {
return false;
......@@ -251,7 +257,7 @@ bool ContextBase::openFiles() {
_files.resize(_fileNames.size());
for (int i = 0; i < (int)_fileNames.size(); i++) {
FileRecordMgr *frm = getNewFRM(_fileNames[i]);
FileRecordMgr *frm = getNewFRM(_fileNames[i], i);
if (hasGenomeFile()) {
frm->setGenomeFile(_genomeFile);
}
......@@ -281,7 +287,7 @@ int ContextBase::getBamHeaderAndRefIdx() {
if (_files[_queryFileIdx]->getFileType() == FileRecordTypeChecker::BAM_FILE_TYPE) {
_bamHeaderAndRefIdx = _queryFileIdx;
} else {
_bamHeaderAndRefIdx = _databaseFileIdx;
_bamHeaderAndRefIdx = _dbFileIdxs[0];
}
return _bamHeaderAndRefIdx;
}
......@@ -492,6 +498,13 @@ bool ContextBase::handle_delim()
return true;
}
// Handler for the "-sortout" flag. The flag takes no value, so only the
// flag's own argument slot is marked as consumed.
// Always returns true: there is nothing to validate.
bool ContextBase::handle_sortout()
{
    // Position of "-sortout" relative to the first argument that is parsed.
    const int argPos = _i - _skipFirstArgs;

    setSortOutput(true);
    markUsed(argPos);
    return true;
}
void ContextBase::setColumnOpsMethods(bool val)
{
if (val && !_hasColumnOpsMethods) {
......@@ -501,20 +514,24 @@ void ContextBase::setColumnOpsMethods(bool val)
_hasColumnOpsMethods = val;
}
const QuickString &ContextBase::getColumnOpsVal(RecordKeyList &keyList) const {
// Return the column-operation result string for the given key/hit set.
// When column operations are not enabled for this context, the shared
// null string is returned instead and _keyListOps is never touched.
const QuickString &ContextBase::getColumnOpsVal(RecordKeyVector &keyList) const {
    if (hasColumnOpsMethods()) {
        return _keyListOps->getOpVals(keyList);
    }
    return _nullStr;
}
FileRecordMgr *ContextBase::getNewFRM(const QuickString &filename) {
if (!_useMergedIntervals) {
return new FileRecordMgr(filename);
} else {
// Create the record manager for one input file and tag it with the
// file's index.
//   filename - name of the file the manager will be constructed with.
//   fileIdx  - index passed to the new manager via setFileIdx().
// Returns a heap-allocated manager; this function does not free it.
FileRecordMgr *ContextBase::getNewFRM(const QuickString &filename, int fileIdx) {
    // Plain case first: no interval merging requested.
    if (!_useMergedIntervals) {
        FileRecordMgr *plainMgr = new FileRecordMgr(filename);
        plainMgr->setFileIdx(fileIdx);
        return plainMgr;
    }

    // Merging requested: build the merge manager and configure it with the
    // strand and distance settings held by this context.
    FileRecordMergeMgr *mergeMgr = new FileRecordMergeMgr(filename);
    mergeMgr->setStrandType(_desiredStrand);
    mergeMgr->setMaxDistance(_maxDistance);
    mergeMgr->setFileIdx(fileIdx);
    return mergeMgr;
}
......
......@@ -98,6 +98,15 @@ public:
virtual bool getSortedInput() const {return _sortedInput; }
virtual void setSortedInput(bool val) { _sortedInput = val; }
// Whether the DB hits for each record should be sorted before output.
// Set to true by the -sortout handler; BinTree::getHits() sorts the hit
// set when this is on.
virtual bool getSortOutput() const {return _sortOutput; }
virtual void setSortOutput(bool val) { _sortOutput = val; }
// Whether user-supplied aliases (-names) are in use for the database
// files. Turned on by the -names handler after the tags are collected.
virtual bool getUseDBnameTags() const { return _reportDBnameTags; }
virtual void setUseDBnameTags(bool val) { _reportDBnameTags = val; }
// Whether complete database file names (-filenames) are requested.
// Turned on by the -filenames handler.
virtual bool getUseDBfileNames() const { return _reportDBfileNames; }
virtual void setUseDBfileNames(bool val) { _reportDBfileNames = val; }
virtual bool getPrintHeader() const {return _printHeader; }
virtual void setPrintHeader(bool val) { _printHeader = val; }
......@@ -107,24 +116,6 @@ public:
virtual bool getUseFullBamTags() const { return _useFullBamTags; }
virtual void setUseFullBamTags(bool val) { _useFullBamTags = val; }
// //
// // MERGE METHODS
// //
// virtual bool getReportCount() const { return _reportCount; }
// virtual void setReportCount(bool val) { _reportCount = val; }
//
// virtual int getMaxDistance() const { return _maxDistance; }
// virtual void setMaxDistance(int distance) { _maxDistance = distance; }
//
// virtual bool getReportNames() const { return _reportNames; }
// virtual void setReportNames(bool val) { _reportNames = val; }
//
// virtual bool getReportScores() const { return _reportScores; }
// virtual void setReportScores(bool val) { _reportScores = val; }
//
// virtual const QuickString &getScoreOp() const { return _scoreOp; }
// virtual void setScoreOp(const QuickString &op) { _scoreOp = op; }
// METHODS FOR PROGRAMS WITH USER_SPECIFIED NUMBER
// OF OUTPUT RECORDS.
......@@ -150,7 +141,7 @@ public:
// are available.
void setColumnOpsMethods(bool val);
virtual bool hasColumnOpsMethods() const { return _hasColumnOpsMethods; }
const QuickString &getColumnOpsVal(RecordKeyList &keyList) const;
const QuickString &getColumnOpsVal(RecordKeyVector &keyList) const;
//methods applicable only to column operations.
protected:
......@@ -192,18 +183,19 @@ protected:
bool _sameStrand;
bool _diffStrand;
bool _sortedInput;
bool _sortOutput;
bool _reportDBnameTags;
bool _reportDBfileNames;
bool _printHeader;
bool _printable;
bool _explicitBedOutput;
int _queryFileIdx;
int _databaseFileIdx;
vector<int> _dbFileIdxs;
vector<QuickString> _dbNameTags;
map<int, int> _fileIdsToDbIdxs;
int _bamHeaderAndRefIdx;
int _maxNumDatabaseFields;
bool _useFullBamTags;
bool _reportCount;
bool _reportNames;
bool _reportScores;
QuickString _scoreOp;
int _numOutputRecords;
......@@ -227,7 +219,7 @@ protected:
bool isUsed(int i) const { return _argsProcessed[i]; }
bool cmdArgsValid();
bool openFiles();
virtual FileRecordMgr *getNewFRM(const QuickString &filename);
virtual FileRecordMgr *getNewFRM(const QuickString &filename, int fileIdx);
//set cmd line params and counter, i, as members so code
//is more readable (as opposed to passing all 3 everywhere).
......@@ -256,6 +248,7 @@ protected:
virtual bool handle_o();
virtual bool handle_null();
virtual bool handle_delim();
virtual bool handle_sortout();
bool parseIoBufSize(QuickString bufStr);
......
......@@ -44,6 +44,12 @@ bool ContextIntersect::parseCmdArgs(int argc, char **argv, int skipFirstArgs) {
else if (strcmp(_argv[_i], "-b") == 0) {
if (!handle_b()) return false;
}
else if (strcmp(_argv[_i], "-names") == 0) {
if (!handle_names()) return false;
}
else if (strcmp(_argv[_i], "-filenames") == 0) {
if (!handle_filenames()) return false;
}
else if (strcmp(_argv[_i], "-u") == 0) {
if (!handle_u()) return false;
}
......@@ -92,7 +98,7 @@ bool ContextIntersect::isValidState()
return false;
}
if (_queryFileIdx == -1 || _databaseFileIdx == -1) {
if (_queryFileIdx == -1 || _dbFileIdxs.size() == -0) {
_errorMsg = "\n***** ERROR: query and database files not specified. *****";
return false;
}
......@@ -113,6 +119,11 @@ bool ContextIntersect::isValidState()
return false;
}
}
if (getUseDBnameTags() && _dbNameTags.size() != _dbFileIdxs.size()) {
_errorMsg = "\n***** ERROR: Number of database name tags given does not match number of databases. *****";
return false;
}
if (getWriteOverlap()) {
if (getWriteA()) {
......@@ -149,7 +160,7 @@ bool ContextIntersect::isValidState()
if (getAnyHit() || getNoHit() || getWriteCount()) {
setPrintable(false);
}
if (_files.size() != 2 ) {
if (_files.size() < 2 ) {
return false;
}
return true;
......@@ -161,8 +172,8 @@ bool ContextIntersect::determineOutputType() {
}
//determine the maximum number of database fields.
for (int i=0; i < (int)_files.size(); i++) {
int numFields = _files[i]->getNumFields();
for (int i=0; i < getNumDatabaseFiles(); i++) {
int numFields = getDatabaseFile(i)->getNumFields();
if ( numFields > _maxNumDatabaseFields) {
_maxNumDatabaseFields = numFields;
}
......@@ -211,19 +222,46 @@ bool ContextIntersect::handle_abam()
bool ContextIntersect::handle_b()
{
if (_argc <= _i+1) {
_errorMsg = "\n***** ERROR: -b option given, but no query file specified. *****";
_errorMsg = "\n***** ERROR: -b option given, but no database file specified. *****";
return false;
}
addInputFile(_argv[_i+1]);
_databaseFileIdx = getNumInputFiles() -1;
markUsed(_i - _skipFirstArgs);
_i++;
markUsed(_i - _skipFirstArgs);
do {
addInputFile(_argv[_i+1]);
int fileId = getNumInputFiles() -1;
_dbFileIdxs.push_back(fileId);
_fileIdsToDbIdxs[fileId] = _dbFileIdxs.size() -1;
markUsed(_i - _skipFirstArgs);
_i++;
markUsed(_i - _skipFirstArgs);
} while (_argc > _i+1 && _argv[_i+1][0] != '-');
return true;
}
bool ContextIntersect::handle_names()
{
if (_argc <= _i+1) {
_errorMsg = "\n***** ERROR: -b option given, but no database names specified. *****";
return false;
}
do {
addDatabaseNameTag(_argv[_i+1]);
markUsed(_i - _skipFirstArgs);
_i++;
markUsed(_i - _skipFirstArgs);
} while (_argc > _i+1 && _argv[_i+1][0] != '-');
setUseDBnameTags(true);
return true;
}
// Handler for the "-filenames" flag. The flag takes no value: its own
// argument slot is marked as consumed and the use of database file names
// is switched on. Always returns true.
bool ContextIntersect::handle_filenames()
{
    // Position of "-filenames" relative to the first argument that is parsed.
    const int argPos = _i - _skipFirstArgs;

    markUsed(argPos);
    setUseDBfileNames(true);
    return true;
}
bool ContextIntersect::handle_c()
{
......
......@@ -22,19 +22,22 @@ public:
//NOTE: Query and database files will only be marked as such by either the
//parseCmdArgs method, or by explicitly setting them.
FileRecordMgr *getQueryFile() { return getFile(_queryFileIdx); }
FileRecordMgr *getDatabaseFile() { return getFile(_databaseFileIdx); }
FileRecordMgr *getDatabaseFile(int idx) { return getFile(_dbFileIdxs[idx]); }
int getQueryFileIdx() const { return _queryFileIdx; }
void setQueryFileIdx(int idx) { _queryFileIdx = idx; }
int getDatabaseFileIdx() const { return _databaseFileIdx; }
void setDatabaseFileIdx(int idx) { _databaseFileIdx = idx; }
int getNumDatabaseFiles() { return (int)_dbFileIdxs.size(); }
const vector<int> &getDbFileIdxs() const { return _dbFileIdxs; }
const QuickString &getQueryFileName() const { return _files[_queryFileIdx]->getFileName(); }
const QuickString &getDatabaseFileName() const { return _files[_databaseFileIdx]->getFileName(); }
const QuickString &getDatabaseFileName(int idx) const { return _files[_dbFileIdxs[idx]]->getFileName(); }
ContextFileType getQueryFileType() const { return _files[_queryFileIdx]->getFileType(); }
ContextFileType getDatabaseFileType() const { return _files[_databaseFileIdx]->getFileType(); }
ContextFileType getDatabaseFileType(int idx) const { return _files[_dbFileIdxs[idx]]->getFileType(); }
ContextRecordType getQueryRecordType() const { return _files[_queryFileIdx]->getRecordType(); }
ContextRecordType getDatabaseRecordType() const { return _files[_databaseFileIdx]->getRecordType(); }
ContextRecordType getDatabaseRecordType(int idx) const { return _files[_dbFileIdxs[idx]]->getRecordType(); }
int getMaxNumDatabaseFields() const { return _maxNumDatabaseFields; }
void setMaxNumDatabaseFields(int val) { _maxNumDatabaseFields = val; }
int getDbIdx(int fileId) { return _fileIdsToDbIdxs.find(fileId)->second; }
// Append an alias for the next database file. handle_names() calls this
// once for every token it consumes after "-names".
void addDatabaseNameTag(const QuickString &tag) { _dbNameTags.push_back(tag); }
// Alias stored at position dbIdx. No bounds check is performed here, so
// dbIdx must be a valid index into the stored tags.
const QuickString &getDatabaseNameTag(int dbIdx) const { return _dbNameTags[dbIdx]; }
bool getAnyHit() const {return _anyHit; }
void setAnyHit(bool val) { _anyHit = val; }
......@@ -83,6 +86,9 @@ private:
virtual bool handle_a();
virtual bool handle_abam();
virtual bool handle_b();
virtual bool handle_names();
virtual bool handle_filenames();