Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
P
Practical-Cryptdb
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Zhaozhen
Practical-Cryptdb
Commits
e4697269
Commit
e4697269
authored
Apr 28, 2018
by
yiwenshao
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
add tools/load_sql.cc
parent
10311069
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
521 additions
and
0 deletions
+521
-0
load_sql.cc
tools/load_sql.cc
+521
-0
No files found.
tools/load_sql.cc
0 → 100644
View file @
e4697269
#include <cstdlib>
#include <cstdio>
#include <string>
#include <map>
#include <iostream>
#include <fstream>
#include <sstream>
#include <stdexcept>
#include <vector>
#include <set>
#include <list>
#include <algorithm>
#include <functional>
#include <cctype>
#include <locale>
#include <unistd.h>
#include <sys/types.h>
#include <dirent.h>
#include <main/Connect.hh>
#include <main/rewrite_main.hh>
#include <main/rewrite_util.hh>
#include <main/sql_handler.hh>
#include <main/dml_handler.hh>
#include <main/ddl_handler.hh>
#include <main/metadata_tables.hh>
#include <main/macro_util.hh>
#include <main/CryptoHandlers.hh>
#include <parser/embedmysql.hh>
#include <parser/stringify.hh>
#include <parser/lex_util.hh>
#include <readline/readline.h>
#include <readline/history.h>
#include <crypto/ecjoin.hh>
#include <util/errstream.hh>
#include <util/cryptdb_log.hh>
#include <util/enum_text.hh>
#include <util/yield.hpp>
#include <sstream>
#include <unistd.h>
#include <map>
using
std
::
cout
;
using
std
::
cin
;
using
std
::
endl
;
using
std
::
string
;
std
::
map
<
SECLEVEL
,
std
::
string
>
gmp
;
std
::
map
<
onion
,
std
::
string
>
gmp2
;
static
std
::
string
embeddedDir
=
"/t/cryt/shadow"
;
//My WrapperState.
class
WrapperState
{
WrapperState
(
const
WrapperState
&
other
);
WrapperState
&
operator
=
(
const
WrapperState
&
rhs
);
KillZone
kill_zone
;
public
:
std
::
string
last_query
;
std
::
string
default_db
;
WrapperState
()
{}
~
WrapperState
()
{}
const
std
::
unique_ptr
<
QueryRewrite
>
&
getQueryRewrite
()
const
{
assert
(
this
->
qr
);
return
this
->
qr
;
}
void
setQueryRewrite
(
std
::
unique_ptr
<
QueryRewrite
>
&&
in_qr
)
{
this
->
qr
=
std
::
move
(
in_qr
);
}
void
selfKill
(
KillZone
::
Where
where
)
{
kill_zone
.
die
(
where
);
}
void
setKillZone
(
const
KillZone
&
kz
)
{
kill_zone
=
kz
;
}
std
::
unique_ptr
<
ProxyState
>
ps
;
std
::
vector
<
SchemaInfoRef
>
schema_info_refs
;
private
:
std
::
unique_ptr
<
QueryRewrite
>
qr
;
};
//global map, for each client, we have one WrapperState which contains ProxyState.
static
std
::
map
<
std
::
string
,
WrapperState
*>
clients
;
//This connection mimics the behaviour of MySQL-Proxy
Connect
*
globalConn
;
//Return values got by using directly the MySQL c Client
struct
rawReturnValue
{
std
::
vector
<
std
::
vector
<
std
::
string
>
>
rowValues
;
std
::
vector
<
std
::
string
>
fieldNames
;
std
::
vector
<
int
>
fieldTypes
;
};
//must be static, or we get "no previous declaration"
//execute the query and getthe rawReturnVale, this struct can be copied.
static
rawReturnValue
executeAndGetResultRemote
(
Connect
*
curConn
,
std
::
string
query
){
std
::
unique_ptr
<
DBResult
>
dbres
;
curConn
->
execute
(
query
,
&
dbres
);
rawReturnValue
myRaw
;
if
(
dbres
==
nullptr
||
dbres
->
n
==
NULL
){
//std::cout<<"no results"<<std::endl;
return
myRaw
;
}
int
num
=
mysql_num_rows
(
dbres
->
n
);
if
(
num
!=
0
)
std
::
cout
<<
"num of rows: "
<<
num
<<
std
::
endl
;
int
numOfFields
=
mysql_num_fields
(
dbres
->
n
);
if
(
numOfFields
!=
0
)
std
::
cout
<<
"num of fields: "
<<
numOfFields
<<
std
::
endl
;
MYSQL_FIELD
*
field
;
MYSQL_ROW
row
;
if
(
num
!=
0
){
while
(
(
row
=
mysql_fetch_row
(
dbres
->
n
))
){
unsigned
long
*
fieldLen
=
mysql_fetch_lengths
(
dbres
->
n
);
std
::
vector
<
std
::
string
>
curRow
;
for
(
int
i
=
0
;
i
<
numOfFields
;
i
++
){
if
(
i
==
0
)
{
while
(
(
field
=
mysql_fetch_field
(
dbres
->
n
))
)
{
myRaw
.
fieldNames
.
push_back
(
std
::
string
(
field
->
name
));
myRaw
.
fieldTypes
.
push_back
(
field
->
type
);
}
}
if
(
row
[
i
]
==
NULL
)
curRow
.
push_back
(
"NULL"
);
else
curRow
.
push_back
(
std
::
string
(
row
[
i
],
fieldLen
[
i
]));
}
myRaw
.
rowValues
.
push_back
(
curRow
);
}
}
return
myRaw
;
}
//print RawReturnValue for testing purposes.
static
void
printrawReturnValue
(
rawReturnValue
&
cur
)
{
int
len
=
cur
.
fieldTypes
.
size
();
if
(
len
==
0
){
//std::cout<<"zero output"<<std::endl;
return
;
}
if
(
static_cast
<
int
>
(
cur
.
fieldNames
.
size
())
!=
len
||
static_cast
<
int
>
(
cur
.
rowValues
[
0
].
size
())
!=
len
){
std
::
cout
<<
RED_BEGIN
<<
"size mismatch in printrawReturnValue"
<<
COLOR_END
<<
std
::
endl
;
return
;
}
for
(
int
i
=
0
;
i
<
len
;
i
++
){
std
::
cout
<<
cur
.
fieldNames
[
i
]
<<
":"
<<
cur
.
fieldTypes
[
i
]
<<
"
\t
"
;
}
std
::
cout
<<
std
::
endl
;
for
(
auto
row
:
cur
.
rowValues
){
for
(
auto
rowItem
:
row
){
std
::
cout
<<
rowItem
<<
"
\t
"
;
}
std
::
cout
<<
std
::
endl
;
}
}
//The rewrite phase of cryptdb.
/*
1. getSchemaInfo
2. rewrite => gather/write
3. fetch the executor and put it in wrapperState(std::unique_ptr<QueryRewrite> qr)
*/
static
bool
myRewrite
(
std
::
string
curQuery
,
unsigned
long
long
_thread_id
,
std
::
string
client
)
{
assert
(
0
==
mysql_thread_init
());
WrapperState
*
const
c_wrapper
=
clients
[
client
];
ProxyState
*
const
ps
=
c_wrapper
->
ps
.
get
();
assert
(
ps
);
c_wrapper
->
last_query
=
curQuery
;
//std::cout<<RED_BEGIN<<"start my rewrite"<<COLOR_END<<std::endl;
try
{
TEST_Text
(
retrieveDefaultDatabase
(
_thread_id
,
ps
->
getConn
(),
&
c_wrapper
->
default_db
),
"proxy failed to retrieve default database!"
);
const
std
::
shared_ptr
<
const
SchemaInfo
>
&
schema
=
ps
->
getSchemaInfo
();
c_wrapper
->
schema_info_refs
.
push_back
(
schema
);
std
::
unique_ptr
<
QueryRewrite
>
qr
=
std
::
unique_ptr
<
QueryRewrite
>
(
new
QueryRewrite
(
Rewriter
::
rewrite
(
curQuery
,
*
schema
.
get
(),
c_wrapper
->
default_db
,
*
ps
)));
assert
(
qr
);
c_wrapper
->
setQueryRewrite
(
std
::
move
(
qr
));
}
catch
(...){
std
::
cout
<<
"rewrite exception!!!"
<<
std
::
endl
;
exit
(
0
);
return
false
;
}
return
true
;
}
//helper function for transforming the rawReturnValue
static
Item_null
*
make_null
(
const
std
::
string
&
name
=
""
){
char
*
const
n
=
current_thd
->
strdup
(
name
.
c_str
());
return
new
Item_null
(
n
);
}
//helper function for transforming the rawReturnValue
static
std
::
vector
<
Item
*>
itemNullVector
(
unsigned
int
count
)
{
std
::
vector
<
Item
*>
out
;
for
(
unsigned
int
i
=
0
;
i
<
count
;
++
i
)
{
out
.
push_back
(
make_null
());
}
return
out
;
}
//transform rawReturnValue to ResType
static
ResType
MygetResTypeFromLuaTable
(
bool
isNULL
,
rawReturnValue
*
inRow
=
NULL
,
int
in_last_insert_id
=
0
){
std
::
vector
<
std
::
string
>
names
;
std
::
vector
<
enum_field_types
>
types
;
std
::
vector
<
std
::
vector
<
Item
*>
>
rows
;
//return NULL restype
if
(
isNULL
){
return
ResType
(
true
,
0
,
0
,
std
::
move
(
names
),
std
::
move
(
types
),
std
::
move
(
rows
));
}
else
{
for
(
auto
inNames
:
inRow
->
fieldNames
){
names
.
push_back
(
inNames
);
}
for
(
auto
inTypes
:
inRow
->
fieldTypes
){
types
.
push_back
(
static_cast
<
enum_field_types
>
(
inTypes
));
}
for
(
auto
inRows
:
inRow
->
rowValues
)
{
std
::
vector
<
Item
*>
curTempRow
=
itemNullVector
(
types
.
size
());
for
(
int
i
=
0
;
i
<
(
int
)(
inRows
.
size
());
i
++
){
curTempRow
[
i
]
=
(
MySQLFieldTypeToItem
(
types
[
i
],
inRows
[
i
])
);
}
rows
.
push_back
(
curTempRow
);
}
//uint64_t afrow = globalConn->get_affected_rows();
//std::cout<<GREEN_BEGIN<<"Affected rows: "<<afrow<<COLOR_END<<std::endl;
return
ResType
(
true
,
0
,
in_last_insert_id
,
std
::
move
(
names
),
std
::
move
(
types
),
std
::
move
(
rows
));
}
}
//printResType for testing purposes
static
void
parseResType
(
const
ResType
&
rd
)
{
// std::cout<<RED_BEGIN<<"rd.affected_rows: "<<rd.affected_rows<<COLOR_END<<std::endl;
// std::cout<<RED_BEGIN<<"rd.insert_id: "<<rd.insert_id<<COLOR_END<<std::endl;
for
(
auto
name
:
rd
.
names
){
std
::
cout
<<
name
<<
"
\t
"
;
}
std
::
cout
<<
std
::
endl
;
for
(
auto
row
:
rd
.
rows
){
for
(
auto
item
:
row
){
std
::
cout
<<
ItemToString
(
*
item
)
<<
"
\t
"
;
}
std
::
cout
<<
std
::
endl
;
}
}
//the "next" phase of cryptdb
/*
1. call function "next" in the executor
2. process three different return types.
*/
static
void
myNext
(
std
::
string
client
,
bool
isFirst
,
ResType
inRes
)
{
WrapperState
*
const
c_wrapper
=
clients
[
client
];
ProxyState
*
const
ps
=
c_wrapper
->
ps
.
get
();
assert
(
ps
);
ps
->
safeCreateEmbeddedTHD
();
const
ResType
&
res
=
inRes
;
const
std
::
unique_ptr
<
QueryRewrite
>
&
qr
=
c_wrapper
->
getQueryRewrite
();
try
{
NextParams
nparams
(
*
ps
,
c_wrapper
->
default_db
,
c_wrapper
->
last_query
);
const
auto
&
new_results
=
qr
->
executor
->
next
(
res
,
nparams
);
const
auto
&
result_type
=
new_results
.
first
;
switch
(
result_type
){
//execute the query, fetch the results, and call next again
case
AbstractQueryExecutor
:
:
ResultType
::
QUERY_COME_AGAIN
:
{
// std::cout<<RED_BEGIN<<"case one"<<COLOR_END<<std::endl;
const
auto
&
output
=
std
::
get
<
1
>
(
new_results
)
->
extract
<
std
::
pair
<
bool
,
std
::
string
>
>
();
const
auto
&
next_query
=
output
.
second
;
//here we execute the query against the remote database, and get rawReturnValue
rawReturnValue
resRemote
=
executeAndGetResultRemote
(
globalConn
,
next_query
);
//transform rawReturnValue first
const
auto
&
againGet
=
MygetResTypeFromLuaTable
(
false
,
&
resRemote
);
myNext
(
client
,
false
,
againGet
);
break
;
}
//only execute the query, without processing the retults
case
AbstractQueryExecutor
:
:
ResultType
::
QUERY_USE_RESULTS
:
{
// std::cout<<RED_BEGIN<<"case two"<<COLOR_END<<std::endl;
const
auto
&
new_query
=
std
::
get
<
1
>
(
new_results
)
->
extract
<
std
::
string
>
();
auto
resRemote
=
executeAndGetResultRemote
(
globalConn
,
new_query
);
printrawReturnValue
(
resRemote
);
break
;
}
//return the results to the client directly
case
AbstractQueryExecutor
:
:
ResultType
::
RESULTS
:
{
// std::cout<<RED_BEGIN<<"case three"<<COLOR_END<<std::endl;
const
auto
&
res
=
new_results
.
second
->
extract
<
ResType
>
();
parseResType
(
res
);
break
;
}
default
:
{
std
::
cout
<<
"case default"
<<
std
::
endl
;
}
}
}
catch
(...){
std
::
cout
<<
"next error"
<<
std
::
endl
;
exit
(
0
);
}
}
static
void
batchTogether
(
std
::
string
client
,
std
::
string
curQuery
,
unsigned
long
long
_thread_id
)
{
//the first step is to Rewrite, we abort this session if we fail here.
bool
resMyRewrite
=
myRewrite
(
curQuery
,
_thread_id
,
client
);
if
(
!
resMyRewrite
){
//std::cout<<"my rewrite error in batch"<<std::endl;
return
;
}
myNext
(
client
,
true
,
MygetResTypeFromLuaTable
(
true
));
}
static
void
processLayers
(
const
EncLayer
&
enc
){
//std::cout<<enc.serialize(enc)<<std::endl;
std
::
cout
<<
enc
.
name
()
<<
std
::
endl
;
}
static
void
processOnionMeta
(
const
OnionMeta
&
onion
){
std
::
cout
<<
GREEN_BEGIN
<<
"PRINT OnionMeta"
<<
COLOR_END
<<
std
::
endl
;
std
::
cout
<<
"onionmeta->getAnonOnionName(): "
<<
onion
.
getAnonOnionName
()
<<
std
::
endl
;
auto
&
layers
=
onion
.
getLayers
();
for
(
auto
&
slayer
:
layers
){
processLayers
(
*
(
slayer
.
get
()));
}
}
static
void
processFieldMeta
(
const
FieldMeta
&
field
){
//Process general info
if
(
field
.
getHasSalt
()){
std
::
cout
<<
"this field has salt"
<<
std
::
endl
;
}
std
::
cout
<<
"field.getFieldName(): "
<<
field
.
getFieldName
()
<<
std
::
endl
;
std
::
cout
<<
"field.getSaltName(): "
<<
field
.
getSaltName
()
<<
std
::
endl
;
std
::
cout
<<
"field.serialize(): "
<<
field
.
serialize
(
field
)
<<
std
::
endl
;
for
(
std
::
pair
<
const
OnionMetaKey
*
,
OnionMeta
*>
&
ompair
:
field
.
orderedOnionMetas
()){
processOnionMeta
(
*
ompair
.
second
);
}
//Process Onions
if
(
field
.
hasOnion
(
oDET
)){
field
.
getOnionMeta
(
oDET
);
}
if
(
field
.
hasOnion
(
oOPE
)){
field
.
getOnionMeta
(
oOPE
);
}
if
(
field
.
hasOnion
(
oAGG
)){
field
.
getOnionMeta
(
oAGG
);
}
return
;
//iterate over onions
for
(
const
std
::
pair
<
const
OnionMetaKey
,
std
::
unique_ptr
<
OnionMeta
>
>
&
onion
:
field
.
getChildren
()){
std
::
cout
<<
onion
.
second
->
getDatabaseID
()
<<
":"
<<
onion
.
first
.
getValue
()
<<
std
::
endl
;
}
}
static
void
processTableMeta
(
const
TableMeta
&
table
){
std
::
cout
<<
GREEN_BEGIN
<<
"PRINT TableMeta"
<<
COLOR_END
<<
std
::
endl
;
for
(
FieldMeta
*
cfm
:
table
.
orderedFieldMetas
()){
processFieldMeta
(
*
cfm
);
}
}
static
void
processDatabaseMeta
(
const
DatabaseMeta
&
dbm
,
std
::
string
table
=
"student1"
)
{
TableMeta
&
tbm
=
*
dbm
.
getChild
(
IdentityMetaKey
(
table
));
processTableMeta
(
tbm
);
return
;
}
static
void
processSchemaInfo
(
SchemaInfo
&
schema
,
std
::
string
db
=
"tdb"
){
const
std
::
unique_ptr
<
AES_KEY
>
&
TK
=
std
::
unique_ptr
<
AES_KEY
>
(
getKey
(
std
::
string
(
"113341234"
)));
Analysis
analysis
(
db
,
schema
,
TK
,
SECURITY_RATING
::
SENSITIVE
);
if
(
analysis
.
databaseMetaExists
(
db
)){
processDatabaseMeta
(
analysis
.
getDatabaseMeta
(
db
));
}
else
{
std
::
cout
<<
"data base not exists"
<<
std
::
endl
;
}
}
static
std
::
unique_ptr
<
SchemaInfo
>
myLoadSchemaInfo
()
{
std
::
unique_ptr
<
Connect
>
e_conn
(
Connect
::
getEmbedded
(
embeddedDir
));
std
::
unique_ptr
<
SchemaInfo
>
schema
(
new
SchemaInfo
());
std
::
function
<
DBMeta
*
(
DBMeta
*
const
)
>
loadChildren
=
[
&
loadChildren
,
&
e_conn
](
DBMeta
*
const
parent
)
{
auto
kids
=
parent
->
fetchChildren
(
e_conn
);
for
(
auto
it
:
kids
)
{
loadChildren
(
it
);
}
return
parent
;
};
//load all metadata and then store it in schema
loadChildren
(
schema
.
get
());
Analysis
analysis
(
std
::
string
(
"student"
),
*
schema
,
std
::
unique_ptr
<
AES_KEY
>
(
getKey
(
std
::
string
(
"113341234"
))),
SECURITY_RATING
::
SENSITIVE
);
return
schema
;
}
int
main
(
int
argc
,
char
**
argv
)
{
gmp
[
SECLEVEL
::
INVALID
]
=
"INVALID"
;
gmp
[
SECLEVEL
::
PLAINVAL
]
=
"PLAINVAL"
;
gmp
[
SECLEVEL
::
OPE
]
=
"OPE"
;
gmp
[
SECLEVEL
::
DETJOIN
]
=
"DETJOIN"
;
gmp
[
SECLEVEL
::
OPEFOREIGN
]
=
"OPEFOREIGN"
;
gmp
[
SECLEVEL
::
DET
]
=
"DET"
;
gmp
[
SECLEVEL
::
SEARCH
]
=
"SEARCH"
;
gmp
[
SECLEVEL
::
HOM
]
=
"HOM"
;
gmp
[
SECLEVEL
::
RND
]
=
"RND"
;
gmp2
[
oDET
]
=
"oDET"
;
gmp2
[
oOPE
]
=
"oOPE"
;
gmp2
[
oAGG
]
=
"oAGG"
;
gmp2
[
oSWP
]
=
"oSWP"
;
gmp2
[
oPLAIN
]
=
"oPLAIN"
;
gmp2
[
oBESTEFFORT
]
=
"oBESTEFFORT"
;
gmp2
[
oINVALID
]
=
"oINVALID"
;
string
targetDb
;
if
(
argc
==
2
){
targetDb
=
string
(
argv
[
1
]);
}
else
{
std
::
cout
<<
"should input targetDb"
<<
std
::
endl
;
return
0
;
}
std
::
string
client
=
"192.168.1.1:1234"
;
//one Wrapper per user.
clients
[
client
]
=
new
WrapperState
();
//Connect phase
ConnectionInfo
ci
(
"localhost"
,
"root"
,
"letmein"
,
3306
);
//const std::string master_key = "113341234";
const
std
::
string
master_key
=
"113341234"
;
char
*
buffer
;
if
((
buffer
=
getcwd
(
NULL
,
0
))
==
NULL
){
perror
(
"getcwd error"
);
}
embeddedDir
=
std
::
string
(
buffer
)
+
"/shadow"
;
SharedProxyState
*
shared_ps
=
new
SharedProxyState
(
ci
,
embeddedDir
,
master_key
,
determineSecurityRating
());
assert
(
0
==
mysql_thread_init
());
//we init embedded database here.
clients
[
client
]
->
ps
=
std
::
unique_ptr
<
ProxyState
>
(
new
ProxyState
(
*
shared_ps
));
clients
[
client
]
->
ps
->
safeCreateEmbeddedTHD
();
//Connect end!!
globalConn
=
new
Connect
(
ci
.
server
,
ci
.
user
,
ci
.
passwd
,
ci
.
port
);
std
::
string
curQuery
=
"SHOW DATABASES;"
;
//int count=0;
//std::cout<<"please input a new query:######################################################"<<std::endl;
if
(
targetDb
.
size
()
==
0
)
std
::
getline
(
std
::
cin
,
curQuery
);
else
curQuery
=
string
(
"use "
)
+
targetDb
;
unsigned
long
long
_thread_id
=
globalConn
->
get_thread_id
();
long
long
countWrapper
=
0
;
while
(
curQuery
!=
"quit"
){
if
(
curQuery
.
size
()
==
0
){
return
0
;
std
::
cout
<<
std
::
endl
;
std
::
getline
(
std
::
cin
,
curQuery
);
std
::
unique_ptr
<
SchemaInfo
>
schema
=
myLoadSchemaInfo
();
processSchemaInfo
(
*
schema
);
continue
;
}
countWrapper
++
;
batchTogether
(
client
,
curQuery
,
_thread_id
);
std
::
getline
(
std
::
cin
,
curQuery
);
if
(
countWrapper
==
3
){
cout
<<
"bingo"
<<
endl
;
countWrapper
=
0
;
}
}
return
0
;
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment